Compare commits: `experiment...maryhipp/b` (508 commits)
508 commits in this range, abbreviated SHA1s `4fd3f62679` through `e70bedba7d`; the author, date, and commit-message columns were not captured in this export.
.github/workflows/style-checks.yml (16 changed lines)

@@ -1,13 +1,14 @@
-name: Black # TODO: add isort and flake8 later
+name: style checks
+# just formatting and flake8 for now
+# TODO: add isort later

 on:
-  pull_request: {}
+  pull_request:
   push:
-    branches: master
-    tags: "*"
+    branches: main

 jobs:
-  test:
+  black:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
@@ -19,9 +20,8 @@ jobs:

       - name: Install dependencies with pip
         run: |
           pip install --upgrade pip wheel
-          pip install .[test]
+          pip install black flake8 Flake8-pyproject

       # - run: isort --check-only .
       - run: black --check .
-      # - run: flake8
+      - run: flake8
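The same checks can be reproduced locally before pushing. A minimal sketch, assuming it is run from the repository root inside the project's Python environment (package names mirror the workflow step above):

```bash
# Install the style tools the workflow uses
pip install black flake8 Flake8-pyproject

# Formatting check: fails if any file would be reformatted
black --check .

# Lint check: Flake8-pyproject lets flake8 read its configuration from pyproject.toml
flake8
```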
.github/workflows/test-invoke-pip-skip.yml (50 changed lines)

The entire file was deleted. Its former contents:

name: Test invoke.py pip

# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable

on:
  pull_request:
    paths:
      - '**'
      - '!pyproject.toml'
      - '!invokeai/**'
      - '!tests/**'
      - 'invokeai/frontend/web/**'
  merge_group:
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  matrix:
    if: github.event.pull_request.draft == false
    strategy:
      matrix:
        python-version:
          - '3.10'
        pytorch:
          - linux-cuda-11_7
          - linux-rocm-5_2
          - linux-cpu
          - macos-default
          - windows-cpu
        include:
          - pytorch: linux-cuda-11_7
            os: ubuntu-22.04
          - pytorch: linux-rocm-5_2
            os: ubuntu-22.04
          - pytorch: linux-cpu
            os: ubuntu-22.04
          - pytorch: macos-default
            os: macOS-12
          - pytorch: windows-cpu
            os: windows-2022
    name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
    runs-on: ${{ matrix.os }}
    steps:
      - name: skip
        run: echo "no build required"
.github/workflows/test-invoke-pip.yml (24 changed lines)

@@ -3,16 +3,7 @@ on:
   push:
     branches:
       - 'main'
-    paths:
-      - 'pyproject.toml'
-      - 'invokeai/**'
-      - '!invokeai/frontend/web/**'
   pull_request:
-    paths:
-      - 'pyproject.toml'
-      - 'invokeai/**'
-      - 'tests/**'
-      - '!invokeai/frontend/web/**'
     types:
       - 'ready_for_review'
       - 'opened'
@@ -65,10 +56,23 @@ jobs:
         id: checkout-sources
         uses: actions/checkout@v3

+      - name: Check for changed python files
+        id: changed-files
+        uses: tj-actions/changed-files@v37
+        with:
+          files_yaml: |
+            python:
+              - 'pyproject.toml'
+              - 'invokeai/**'
+              - '!invokeai/frontend/web/**'
+              - 'tests/**'
+
       - name: set test prompt to main branch validation
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}

       - name: setup python
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
@@ -76,6 +80,7 @@
           cache-dependency-path: pyproject.toml

       - name: install invokeai
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         env:
           PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
         run: >
@@ -83,6 +88,7 @@
           --editable=".[test]"

       - name: run pytest
+        if: steps.changed-files.outputs.python_any_changed == 'true'
         id: run-pytest
         run: pytest
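The install and test steps above translate directly to a local run. A sketch, assuming a checkout of the repository with an activated virtual environment (the CI matrix's extra PyTorch index URL is only needed when a specific CUDA/ROCm build is wanted):

```bash
# Editable install with the test extras, mirroring the workflow's install step
pip install --upgrade pip wheel
pip install --editable ".[test]"

# Run the suite, as the workflow's "run pytest" step does
pytest
```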
.gitignore (37 changed lines)

@@ -1,23 +1,8 @@
 # ignore default image save location and model symbolic link
+.idea/
 embeddings/
 outputs/
-models/ldm/stable-diffusion-v1/model.ckpt
-**/restoration/codeformer/weights
-
-# ignore user models config
-configs/models.user.yaml
-config/models.user.yml
-invokeai.init
-.version
-.last_model
-
-# ignore the Anaconda/Miniconda installer used while building Docker image
-anaconda.sh
-
-# ignore a directory which serves as a place for initial images
-inputs/

 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -189,39 +174,17 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

-src
-**/__pycache__/
-outputs
-
-# Logs and associated folders
-# created from generated embeddings.
-logs
-testtube
-checkpoints
 # If it's a Mac
 .DS_Store

-invokeai/frontend/yarn.lock
-invokeai/frontend/node_modules
-
 # Let the frontend manage its own gitignore
 !invokeai/frontend/web/*

 # Scratch folder
 .scratch/
 .vscode/
-gfpgan/
-models/ldm/stable-diffusion-v1/*.sha256
-
-
-# GFPGAN model files
-gfpgan/
-
-# config file (will be created by installer)
-configs/models.yaml
-
-# ignore initfile
-.invokeai
-
-# ignore environment.yml and requirements.txt
-# these are links to the real files in environments-and-requirements
(file name not captured in this export; the `repos:` hunk context and the black/flake8 hook entries indicate a pre-commit configuration file)

@@ -8,3 +8,10 @@ repos:
         language: system
         entry: black
         types: [python]
+
+      - id: flake8
+        name: flake8
+        stages: [commit]
+        language: system
+        entry: flake8
+        types: [python]
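Assuming the standard pre-commit tool drives these hooks (implied by the `repos:` layout above), they can be exercised locally like this. A sketch, not taken from the repository's own documentation:

```bash
# One-time setup: install pre-commit and register its git hook in this clone
pip install pre-commit
pre-commit install

# Run every configured hook (black, flake8, ...) against the whole tree
pre-commit run --all-files
```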
README.md (42 changed lines)

@@ -43,7 +43,7 @@ Web Interface, interactive Command Line Interface, and also serves as
 the foundation for multiple commercial products.

 **Quick links**: [[How to
-Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a
+Install](https://invoke-ai.github.io/InvokeAI/installation/INSTALLATION/)] [<a
 href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a
 href="https://invoke-ai.github.io/InvokeAI/">Documentation and
 Tutorials</a>] [<a
@@ -81,7 +81,7 @@ Table of Contents 📝
 ## Quick Start

 For full installation and upgrade instructions, please see:
-[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)
+[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/INSTALLATION/)

 If upgrading from version 2.3, please read [Migrating a 2.3 root
 directory to 3.0](#migrating-to-3) first.
@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
 _For Windows/Linux with an NVIDIA GPU:_

    ```terminal
-   pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+   pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
    ```

 _For Linux with an AMD GPU:_
@@ -184,8 +184,9 @@ the command `npm install -g yarn` if needed)
 6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):

    ```terminal
-   invokeai-configure
+   invokeai-configure --root .
    ```
+   Don't miss the dot at the end!

 7. Launch the web server (do it every time you run InvokeAI):
@@ -193,15 +194,9 @@ the command `npm install -g yarn` if needed)
    invokeai-web
    ```

-8. Build Node.js assets
+8. Point your browser to http://localhost:9090 to bring up the web interface.

-   ```terminal
-   cd invokeai/frontend/web/
-   yarn vite build
-   ```
-
-9. Point your browser to http://localhost:9090 to bring up the web interface.
-10. Type `banana sushi` in the box on the top left and click `Invoke`.
+9. Type `banana sushi` in the box on the top left and click `Invoke`.

 Be sure to activate the virtual environment each time before re-launching InvokeAI,
 using `source .venv/bin/activate` or `.venv\Scripts\activate`.
@@ -311,13 +306,30 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
 You may now launch the WebUI in the usual way, by selecting option [1]
 from the launcher script

-#### Migration Caveats
+#### Migrating Images

 The migration script will migrate your invokeai settings and models,
 including textual inversion models, LoRAs and merges that you may have
 installed previously. However it does **not** migrate the generated
-images stored in your 2.3-format outputs directory. You will need to
-manually import selected images into the 3.0 gallery via drag-and-drop.
+images stored in your 2.3-format outputs directory. To do this, you
+need to run an additional step:
+
+1. From a working InvokeAI 3.0 root directory, start the launcher and
+   enter menu option [8] to open the "developer's console".
+
+2. At the developer's console command line, type the command:
+
+   ```bash
+   invokeai-import-images
+   ```
+
+3. This will lead you through the process of confirming the desired
+   source and destination for the imported images. The images will
+   appear in the gallery board of your choice, and contain the
+   original prompt, model name, and other parameters used to generate
+   the image.
+
+(Many kudos to **techjedi** for contributing this script.)

 ## Hardware Requirements
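Taken together, the revised Quick Start steps above boil down to the following command sequence. A sketch, assuming a Windows/Linux machine with an NVIDIA GPU and a virtual environment already created at `.venv` as described earlier in the README:

```bash
# Activate the virtual environment (use ".venv\Scripts\activate" on Windows)
source .venv/bin/activate

# Install InvokeAI with xformers support against the CUDA 11.8 PyTorch wheels
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118

# One-time configuration; the trailing dot makes the current directory the root
invokeai-configure --root .

# Launch the web server, then browse to http://localhost:9090
invokeai-web
```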
(file name not captured in this export; the content is the Docker entrypoint shell script)

@@ -29,8 +29,8 @@ configure() {
         echo "To reconfigure InvokeAI, delete the above file."
         echo "======================================================================"
     else
-        mkdir -p ${INVOKEAI_ROOT}
-        chown --recursive ${USER} ${INVOKEAI_ROOT}
+        mkdir -p "${INVOKEAI_ROOT}"
+        chown --recursive ${USER} "${INVOKEAI_ROOT}"
         gosu ${USER} invokeai-configure --yes --default_only
     fi
 }
@@ -50,16 +50,16 @@ fi
 if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
     apt-get update
     apt-get install -y openssh-server
-    pushd $HOME
+    pushd "$HOME"
     mkdir -p .ssh
-    echo ${PUBLIC_KEY} > .ssh/authorized_keys
+    echo "${PUBLIC_KEY}" > .ssh/authorized_keys
     chmod -R 700 .ssh
     popd
     service ssh start
 fi

-cd ${INVOKEAI_ROOT}
+cd "${INVOKEAI_ROOT}"

 # Run the CMD as the Container User (not root).
 exec gosu ${USER} "$@"
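The change is purely about quoting: each expansion of `${INVOKEAI_ROOT}`, `$HOME`, and `${PUBLIC_KEY}` is wrapped in double quotes so values containing spaces or glob characters are passed as a single word. A small illustration using a hypothetical path, not one from the repository:

```bash
INVOKEAI_ROOT="/data/invoke ai"   # hypothetical value containing a space

mkdir -p ${INVOKEAI_ROOT}         # unquoted: word-splits into two arguments,
                                  # creating /data/invoke and ./ai
mkdir -p "${INVOKEAI_ROOT}"       # quoted: creates the single intended directory
```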
Seven binary image assets were also changed in this range (file sizes roughly 179-557 KiB before and 197-598 KiB after); their file names and dimensions were not captured in this export.
(documentation file; name not captured in this export; a contributing overview page)

@@ -14,11 +14,14 @@ To join, just raise your hand on the InvokeAI Discord server (#dev-chat) or the
 #### Development
 If you’d like to help with development, please see our [development guide](contribution_guides/development.md). If you’re unfamiliar with contributing to open source projects, there is a tutorial contained within the development guide.

+#### Nodes
+If you’d like to help with development, please see our [nodes contribution guide](/nodes/contributingNodes). If you’re unfamiliar with contributing to open source projects, there is a tutorial contained within the development guide.
+
 #### Documentation
-If you’d like to help with documentation, please see our [documentation guide](contribution_guides/documenation.md).
+If you’d like to help with documentation, please see our [documentation guide](contribution_guides/documentation.md).

 #### Translation
-If you'd like to help with translation, please see our [translation guide](docs/contributing/.contribution_guides/translation.md).
+If you'd like to help with translation, please see our [translation guide](contribution_guides/translation.md).

 #### Tutorials
 Please reach out to @imic or @hipsterusername on [Discord](https://discord.gg/ZmtBAhwWhy) to help create tutorials for InvokeAI.
(documentation file; name not captured in this export; the guide to writing custom Invocations/Nodes)

@@ -270,9 +270,12 @@ new Invocation ready to be used.

 

-# Advanced
+## Contributing Nodes
+Once you've created a Node, the next step is to share it with the community! The best way to do this is to submit a Pull Request to add the Node to the [Community Nodes](nodes/communityNodes) list. If you're not sure how to do that, take a look a at our [contributing nodes overview](contributingNodes).
+
-## Custom Input Fields
+## Advanced
+
+### Custom Input Fields

 Now that you know how to create your own Invocations, let us dive into slightly
 more advanced topics.
@@ -352,7 +355,7 @@ input field.

 We will discuss the `Config` class in extra detail later in this guide and how
 you can use it to make your Invocations more robust.

-## Custom Output Types
+### Custom Output Types

 Like with custom inputs, sometimes you might find yourself needing custom
 outputs that InvokeAI does not provide. We can easily set one up.
@@ -396,7 +399,7 @@ All set. We now have an output type that requires what we need to create a
 blank_image. And if you noticed it, we even used the `Config` class to ensure
 the fields are required.

-## Custom Configuration
+### Custom Configuration

 As you might have noticed when making inputs and outputs, we used a class called
 `Config` from _pydantic_ to further customize them. Because our inputs and
@@ -492,7 +495,7 @@ later time.

 # **[TODO]**

-## Custom Components For Frontend
+### Custom Components For Frontend

 Every backend input type should have a corresponding frontend component so the
 UI knows what to render when you use a particular field type.
@@ -513,7 +516,7 @@ now.

 ---

-# OLD -- TO BE DELETED OR MOVED LATER
+<!-- # OLD -- TO BE DELETED OR MOVED LATER

 ---

@@ -787,4 +790,5 @@ With the customization in place, the schema will now show these properties as
 required, obviating the need for extensive null checks in client code.

 See this `pydantic` issue for discussion on this solution:
-<https://github.com/pydantic/pydantic/discussions/4577>
+<https://github.com/pydantic/pydantic/discussions/4577> -->
(documentation file; name not captured in this export; the configuration settings reference)

@@ -175,22 +175,27 @@ These configuration settings allow you to enable and disable various InvokeAI features
 | `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
 | `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
 | `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
-| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |

-### Memory/Performance
+### Generation

 These options tune InvokeAI's memory and performance characteristics.

 | Setting | Default Value | Description |
-|----------|----------------|--------------|
-| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
-| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
-| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
-| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
-| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
+|-----------------------|---------------|--------------|
 | `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
-| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
-| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
+| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
+| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
+| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
+
+### Device
+
+These options configure the generation execution device.
+
+| Setting | Default Value | Description |
+|-----------------------|---------------|--------------|
+| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
+| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |

 ### Paths
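These settings are normally supplied through the `invokeai.yaml` file in the InvokeAI root directory or as launch-time switches; the exact mechanisms are covered earlier in the configuration guide and are not part of this diff. A purely illustrative launch command, assuming the switches mirror the setting names above (check `invokeai-web --help` for the exact spelling):

```bash
# Hypothetical overrides of the Generation and Device settings documented above
invokeai-web \
  --attention_type sliced \
  --attention_slice_size 4 \
  --device cuda:1 \
  --precision float16
```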
(documentation file; name not captured in this export; the experimental Nodes Editor guide. The entire 208-line file was removed in this change; its former contents follow.)

# Nodes Editor (Experimental)

🚨
*The node editor is experimental. We've made it accessible because we use it to develop the application, but we have not addressed the many known rough edges. It's very easy to shoot yourself in the foot, and we cannot offer support for it until it sees full release (ETA v3.1). Everything is subject to change without warning.*
🚨

The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. The node processing flow is usually done from left (inputs) to right (outputs), though linearity can become abstracted the more complex the node graph becomes. Nodes inputs and outputs are connected by dragging connectors from node to node.

To better understand how nodes are used, think of how an electric power bar works. It takes in one input (electricity from a wall outlet) and passes it to multiple devices through multiple outputs. Similarly, a node could have multiple inputs and outputs functioning at the same (or different) time, but all node outputs pass information onward like a power bar passes electricity. Not all outputs are compatible with all inputs, however - Each node has different constraints on how it is expecting to input/output information. In general, node outputs are colour-coded to match compatible inputs of other nodes.

## Anatomy of a Node

Individual nodes are made up of the following:

- Inputs: Edge points on the left side of the node window where you connect outputs from other nodes.
- Outputs: Edge points on the right side of the node window where you connect to inputs on other nodes.
- Options: Various options which are either manually configured, or overridden by connecting an output from another node to the input.

## Diffusion Overview

Taking the time to understand the diffusion process will help you to understand how to set up your nodes in the nodes editor.

There are two main spaces Stable Diffusion works in: image space and latent space.

Image space represents images in pixel form that you look at. Latent space represents compressed inputs. It’s in latent space that Stable Diffusion processes images. A VAE (Variational Auto Encoder) is responsible for compressing and encoding inputs into latent space, as well as decoding outputs back into image space.

When you generate an image using text-to-image, multiple steps occur in latent space:
1. Random noise is generated at the chosen height and width. The noise’s characteristics are dictated by the chosen (or not chosen) seed. This noise tensor is passed into latent space. We’ll call this noise A.
1. Using a model’s U-Net, a noise predictor examines noise A, and the words tokenized by CLIP from your prompt (conditioning). It generates its own noise tensor to predict what the final image might look like in latent space. We’ll call this noise B.
1. Noise B is subtracted from noise A in an attempt to create a final latent image indicative of the inputs. This step is repeated for the number of sampler steps chosen.
1. The VAE decodes the final latent image from latent space into image space.

image-to-image is a similar process, with only step 1 being different:
1. The input image is decoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how much noise is added, 0 being none, and 1 being all-encompassing. We’ll call this noise A. The process is then the same as steps 2-4 in the text-to-image explanation above.

Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).

A noise scheduler (eg. DPM++ 2M Karras) schedules the subtraction of noise from the latent image across the sampler steps chosen (step 3 above). Less noise is usually subtracted at higher sampler steps.

## Node Types (Base Nodes)

| Node <img width=160 align="right"> | Function |
| ---------------------------------- | --------------------------------------------------------------------------------------|
| Add | Adds two numbers |
| CannyImageProcessor | Canny edge detection for ControlNet |
| ClipSkip | Skip layers in clip text_encoder model |
| Collect | Collects values into a collection |
| Prompt (Compel) | Parse prompt using compel package to conditioning |
| ContentShuffleImageProcessor | Applies content shuffle processing to image |
| ControlNet | Collects ControlNet info to pass to other nodes |
| CvInpaint | Simple inpaint using opencv |
| Divide | Divides two numbers |
| DynamicPrompt | Parses a prompt using adieyal/dynamic prompt's random or combinatorial generator |
| FloatLinearRange | Creates a range |
| HedImageProcessor | Applies HED edge detection to image |
| ImageBlur | Blurs an image |
| ImageChannel | Gets a channel from an image |
| ImageCollection | Load a collection of images and provide it as output |
| ImageConvert | Converts an image to a different mode |
| ImageCrop | Crops an image to a specified box. The box can be outside of the image. |
| ImageInverseLerp | Inverse linear interpolation of all pixels of an image |
| ImageLerp | Linear interpolation of all pixels of an image |
| ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` |
| ImageNSFWBlurInvocation | Detects and blurs images that may contain sexually explicit content |
| ImagePaste | Pastes an image into another image |
| ImageProcessor | Base class for invocations that reprocess images for ControlNet |
| ImageResize | Resizes an image to specific dimensions |
| ImageScale | Scales an image by a factor |
| ImageToLatents | Scales latents by a given factor |
| ImageWatermarkInvocation | Adds an invisible watermark to images |
| InfillColor | Infills transparent areas of an image with a solid color |
| InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm |
| InfillTile | Infills transparent areas of an image with tiles of the image |
| Inpaint | Generates an image using inpaint |
| Iterate | Iterates over a list of items |
| LatentsToImage | Generates an image from latents |
| LatentsToLatents | Generates latents using latents as base image |
| LeresImageProcessor | Applies leres processing to image |
| LineartAnimeImageProcessor | Applies line art anime processing to image |
| LineartImageProcessor | Applies line art processing to image |
| LoadImage | Load an image and provide it as output |
| Lora Loader | Apply selected lora to unet and text_encoder |
| Model Loader | Loads a main model, outputting its submodels |
| MaskFromAlpha | Extracts the alpha channel of an image as a mask |
| MediapipeFaceProcessor | Applies mediapipe face processing to image |
| MidasDepthImageProcessor | Applies Midas depth processing to image |
| MlsdImageProcessor | Applied MLSD processing to image |
| Multiply | Multiplies two numbers |
| Noise | Generates latent noise |
| NormalbaeImageProcessor | Applies NormalBAE processing to image |
| OpenposeImageProcessor | Applies Openpose processing to image |
| ParamFloat | A float parameter |
| ParamInt | An integer parameter |
| PidiImageProcessor | Applies PIDI processing to an image |
| Progress Image | Displays the progress image in the Node Editor |
| RandomInit | Outputs a single random integer |
| RandomRange | Creates a collection of random numbers |
| Range | Creates a range of numbers from start to stop with step |
| RangeOfSize | Creates a range from start to start + size with step |
| ResizeLatents | Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8. |
| RestoreFace | Restores faces in the image |
| ScaleLatents | Scales latents by a given factor |
| SegmentAnythingProcessor | Applies segment anything processing to image |
| ShowImage | Displays a provided image, and passes it forward in the pipeline |
| StepParamEasing | Experimental per-step parameter for easing for denoising steps |
| Subtract | Subtracts two numbers |
| TextToLatents | Generates latents from conditionings |
| TileResampleProcessor | Bass class for invocations that preprocess images for ControlNet |
| Upscale | Upscales an image |
| VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput |
| ZoeDepthImageProcessor | Applies Zoe depth processing to image |

## Node Grouping Concepts

There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).

### Noise

As described, an initial noise tensor is necessary for the latent diffusion process. As a result, all non-image *ToLatents nodes require a noise node input.



### Conditioning

As described, conditioning is necessary for the latent diffusion process, whether empty or not. As a result, all non-image *ToLatents nodes require positive and negative conditioning inputs. Conditioning is reliant on a CLIP tokenizer provided by the Model Loader node.



### Image Space & VAE

The ImageToLatents node doesn't require a noise node input, but requires a VAE input to convert the image from image space into latent space. In reverse, the LatentsToImage node requires a VAE input to convert from latent space back into image space.



### Defined & Random Seeds

It is common to want to use both the same seed (for continuity) and random seeds (for variance). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.



### Control

Control means to guide the diffusion process to adhere to a defined input or structure. Control can be provided as input to non-image *ToLatents nodes from ControlNet nodes. ControlNet nodes usually require an image processor which converts an input image for use with ControlNet.



### LoRA

The Lora Loader node lets you load a LoRA (say that ten times fast) and pass it as output to both the Prompt (Compel) and non-image *ToLatents nodes. A model's CLIP tokenizer is passed through the LoRA into Prompt (Compel), where it affects conditioning. A model's U-Net is also passed through the LoRA into a non-image *ToLatents node, where it affects noise prediction.



### Scaling

Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results.



### Iteration + Multiple Images as Input

Iteration is a common concept in any processing, and means to repeat a process with given input. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and pass them out one at a time.



### Multiple Image Generation + Random Seeds

Multiple image generation in the node editor is done using the RandomRange node. In this case, the 'Size' field represents the number of images to generate. As RandomRange produces a collection of integers, we need to add the Iterate node to iterate through the collection.

To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results.



## Examples

With our knowledge of node grouping and the diffusion process, let’s break down some basic graphs in the nodes editor. Note that a node's options can be overridden by inputs from other nodes. These examples aren't strict rules to follow and only demonstrate some basic configurations.

### Basic text-to-image Node Graph



- Model Loader: A necessity to generating images (as we’ve read above). We choose our model from the dropdown. It outputs a U-Net, CLIP tokenizer, and VAE.
- Prompt (Compel): Another necessity. Two prompt nodes are created. One will output positive conditioning (what you want, ‘dog’), one will output negative (what you don’t want, ‘cat’). They both input the CLIP tokenizer that the Model Loader node outputs.
- Noise: Consider this noise A from step one of the text-to-image explanation above. Choose a seed number, width, and height.
- TextToLatents: This node takes many inputs for converting and processing text & noise from image space into latent space, hence the name TextTo**Latents**. In this setup, it inputs positive and negative conditioning from the prompt nodes for processing (step 2 above). It inputs noise from the noise node for processing (steps 2 & 3 above). Lastly, it inputs a U-Net from the Model Loader node for processing (step 2 above). It outputs latents for use in the next LatentsToImage node. Choose number of sampler steps, CFG scale, and scheduler.
- LatentsToImage: This node takes in processed latents from the TextToLatents node, and the model’s VAE from the Model Loader node which is responsible for decoding latents back into the image space, hence the name LatentsTo**Image**. This node is the last stop, and once the image is decoded, it is saved to the gallery.

### Basic image-to-image Node Graph



- Model Loader: Choose a model from the dropdown.
- Prompt (Compel): Two prompt nodes. One positive (dog), one negative (dog). Same CLIP inputs from the Model Loader node as before.
- ImageToLatents: Upload a source image directly in the node window, via drag'n'drop from the gallery, or passed in as input. The ImageToLatents node inputs the VAE from the Model Loader node to decode the chosen image from image space into latent space, hence the name ImageTo**Latents**. It outputs latents for use in the next LatentsToLatents node. It also outputs the source image's width and height for use in the next Noise node if the final image is to be the same dimensions as the source image.
- Noise: A noise tensor is created with the width and height of the source image, and connected to the next LatentsToLatents node. Notice the width and height fields are overridden by the input from the ImageToLatents width and height outputs.
- LatentsToLatents: The inputs and options are nearly identical to TextToLatents, except that LatentsToLatents also takes latents as an input. Considering our source image is already converted to latents in the last ImageToLatents node, and text + noise are no longer the only inputs to process, we use the LatentsToLatents node.
- LatentsToImage: Like previously, the LatentsToImage node will use the VAE from the Model Loader as input to decode the latents from LatentsToLatents into image space, and save it to the gallery.

### Basic ControlNet Node Graph



- Model Loader
- Prompt (Compel)
- Noise: Width and height of the CannyImageProcessor ControlNet image is passed in to set the dimensions of the noise passed to TextToLatents.
- CannyImageProcessor: The CannyImageProcessor node is used to process the source image being used as a ControlNet. Each ControlNet processor node applies control in different ways, and has some different options to configure. Width and height are passed to noise, as mentioned. The processed ControlNet image is output to the ControlNet node.
- ControlNet: Select the type of control model. In this case, canny is chosen as the CannyImageProcessor was used to generate the ControlNet image. Configure the control node options, and pass the control output to TextToLatents.
- TextToLatents: Similar to the basic text-to-image example, except ControlNet is passed to the control input edge point.
- LatentsToImage
(documentation file; name not captured in this export; the Postprocessing feature page)

@@ -4,35 +4,13 @@ title: Postprocessing
 # :material-image-edit: Postprocessing

 ## Intro

-This extension provides the ability to restore faces and upscale images.
+This sections details the ability to improve faces and upscale images.

 ## Face Fixing

-The default face restoration module is GFPGAN. The default upscale is
-Real-ESRGAN. For an alternative face restoration module, see
-[CodeFormer Support](#codeformer-support) below.
+As of InvokeAI 3.0, the easiest way to improve faces created during image generation is through the Inpainting functionality of the Unified Canvas. Simply add the image containing the faces that you would like to improve to the canvas, mask the face to be improved and run the invocation. For best results, make sure to use an inpainting specific model; these are usually identified by the "-inpainting" term in the model name.

-As of version 1.14, environment.yaml will install the Real-ESRGAN package into
-the standard install location for python packages, and will put GFPGAN into a
-subdirectory of "src" in the InvokeAI directory. Upscaling with Real-ESRGAN
-should "just work" without further intervention. Simply indicate the desired scale on
-the popup in the Web GUI.
-
-**GFPGAN** requires a series of downloadable model files to work. These are
-loaded when you run `invokeai-configure`. If GFPAN is failing with an
-error, please run the following from the InvokeAI directory:
-
-```bash
-invokeai-configure
-```
-
-If you do not run this script in advance, the GFPGAN module will attempt to
-download the models files the first time you try to perform facial
-reconstruction.
-
-### Upscaling
+## Upscaling

 Open the upscaling dialog by clicking on the "expand" icon located
 above the image display area in the Web UI:
@@ -41,82 +19,23 @@ above the image display area in the Web UI:
 
 </figure>

-There are three different upscaling parameters that you can
-adjust. The first is the scale itself, either 2x or 4x.
+The default upscaling option is Real-ESRGAN x2 Plus, which will scale your image by a factor of two. This means upscaling a 512x512 image will result in a new 1024x1024 image.

-The second is the "Denoising Strength." Higher values will smooth out
-the image and remove digital chatter, but may lose fine detail at
-higher values.
+Other options are the x4 upscalers, which will scale your image by a factor of 4.

-Third, "Upscale Strength" allows you to adjust how the You can set the
-scaling stength between `0` and `1.0` to control the intensity of the
-scaling. AI upscalers generally tend to smooth out texture details. If
-you wish to retain some of those for natural looking results, we
-recommend using values between `0.5 to 0.8`.
-
-[This figure](../assets/features/upscaling-montage.png) illustrates
-the effects of denoising and strength. The original image was 512x512,
-4x scaled to 2048x2048. The "original" version on the upper left was
-scaled using simple pixel averaging. The remainder use the ESRGAN
-upscaling algorithm at different levels of denoising and strength.
-
-<figure markdown>
-{ width=720 }
-</figure>
-
-Both denoising and strength default to 0.75.
-
-### Face Restoration
-
-InvokeAI offers alternative two face restoration algorithms,
-[GFPGAN](https://github.com/TencentARC/GFPGAN) and
-[CodeFormer](https://huggingface.co/spaces/sczhou/CodeFormer). These
-algorithms improve the appearance of faces, particularly eyes and
-mouths. Issues with faces are less common with the latest set of
-Stable Diffusion models than with the original 1.4 release, but the
-restoration algorithms can still make a noticeable improvement in
-certain cases. You can also apply restoration to old photographs you
-upload.
-
-To access face restoration, click the "smiley face" icon in the
-toolbar above the InvokeAI image panel. You will be presented with a
-dialog that offers a choice between the two algorithm and sliders that
-allow you to adjust their parameters. Alternatively, you may open the
-left-hand accordion panel labeled "Face Restoration" and have the
-restoration algorithm of your choice applied to generated images
-automatically.
-
-Like upscaling, there are a number of parameters that adjust the face
-restoration output. GFPGAN has a single parameter, `strength`, which
-controls how much the algorithm is allowed to adjust the
-image. CodeFormer has two parameters, `strength`, and `fidelity`,
-which together control the quality of the output image as described in
-the [CodeFormer project
-page](https://shangchenzhou.com/projects/CodeFormer/). Default values
-are 0.75 for both parameters, which achieves a reasonable balance
-between changing the image too much and not enough.
-
-[This figure](../assets/features/restoration-montage.png) illustrates
-the effects of adjusting GFPGAN and CodeFormer parameters.
-
-<figure markdown>
-{ width=720 }
-</figure>

 !!! note

-    GFPGAN and Real-ESRGAN are both memory intensive. In order to avoid crashes and memory overloads
+    Real-ESRGAN is memory intensive. In order to avoid crashes and memory overloads
     during the Stable Diffusion process, these effects are applied after Stable Diffusion has completed
     its work.

     In single image generations, you will see the output right away but when you are using multiple
-    iterations, the images will first be generated and then upscaled and face restored after that
+    iterations, the images will first be generated and then upscaled after that
     process is complete. While the image generation is taking place, you will still be able to preview
     the base images.

 ## How to disable

-If, for some reason, you do not wish to load the GFPGAN and/or ESRGAN libraries,
-you can disable them on the invoke.py command line with the `--no_restore` and
-`--no_esrgan` options, respectively.
+If, for some reason, you do not wish to load the ESRGAN libraries,
+you can disable them on the invoke.py command line with the `--no_esrgan` options.
@@ -4,80 +4,6 @@ title: Prompting-Features
|
||||
|
||||
# :octicons-command-palette-24: Prompting-Features
|
||||
|
||||
## **Negative and Unconditioned Prompts**
|
||||
|
||||
Any words between a pair of square brackets will instruct Stable
|
||||
Diffusion to attempt to ban the concept from the generated image. The
|
||||
same effect is achieved by placing words in the "Negative Prompts"
|
||||
textbox in the Web UI.
|
||||
|
||||
```text
|
||||
this is a test prompt [not really] to make you understand [cool] how this works.
|
||||
```
|
||||
|
||||
In the above statement, the words 'not really cool` will be ignored by Stable
|
||||
Diffusion.
|
||||
|
||||
Here's a prompt that depicts what it does.
|
||||
|
||||
original prompt:
|
||||
|
||||
`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve"`
|
||||
|
||||
`#!bash parameters: steps=20, dimensions=512x768, CFG=7.5, Scheduler=k_euler_a, seed=1654590180`
|
||||
|
||||
<figure markdown>
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
That image has a woman, so if we want the horse without a rider, we can
|
||||
influence the image not to have a woman by putting [woman] in the prompt, like
|
||||
this:
|
||||
|
||||
`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman]"`
|
||||
(same parameters as above)
|
||||
|
||||
<figure markdown>
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
That's nice - but say we also don't want the image to be quite so blue. We can
|
||||
add "blue" to the list of negative prompts, so it's now [woman blue]:
|
||||
|
||||
`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman blue]"`
|
||||
(same parameters as above)
|
||||
|
||||
<figure markdown>
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
Getting close - but there's no sense in having a saddle when our horse doesn't
|
||||
have a rider, so we'll add one more negative prompt: [woman blue saddle].
|
||||
|
||||
`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve [woman blue saddle]"`
|
||||
(same parameters as above)
|
||||
|
||||
<figure markdown>
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
!!! notes "Notes about this feature:"
|
||||
|
||||
* The only requirement for words to be ignored is that they are in between a pair of square brackets.
|
||||
* You can provide multiple words within the same bracket.
|
||||
* You can provide multiple brackets with multiple words in different places of your prompt. That works just fine.
|
||||
* To improve typical anatomy problems, you can add negative prompts like `[bad anatomy, extra legs, extra arms, extra fingers, poorly drawn hands, poorly drawn feet, disfigured, out of frame, tiling, bad art, deformed, mutated]`.
|
||||
|
||||
---
|
||||
|
||||
## **Prompt Syntax Features**
|
||||
|
||||
The InvokeAI prompting language has the following features:
|
||||
@@ -102,9 +28,6 @@ The following syntax is recognised:
|
||||
`a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++`
|
||||
is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2),
|
||||
etc.)
|
||||
- attention also applies to `[unconditioning]` so
|
||||
`a tall thin man picking apricots [(ladder)0.01]` will _very gently_ nudge SD
|
||||
away from trying to draw the man on a ladder
|
||||
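The `+`/`-` shorthand above boils down to simple multipliers. Here is a minimal sketch of just that arithmetic (a hypothetical helper for illustration, not InvokeAI's actual parser):

```python
def plus_minus_weight(symbols: str) -> float:
    # "+" multiplies attention by 1.1 per character and "-" by 0.9 per character,
    # so "++" -> 1.1**2 == 1.21 and "--" -> 0.9**2 == 0.81, as described above.
    base = {"+": 1.1, "-": 0.9}[symbols[0]]
    return base ** len(symbols)

print(plus_minus_weight("+"))    # 1.1
print(plus_minus_weight("++"))   # ~1.21
print(plus_minus_weight("---"))  # ~0.729
```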
|
||||
You can use this to increase or decrease the amount of something. Starting from
|
||||
this prompt of `a man picking apricots from a tree`, let's see what happens if
|
||||
@@ -150,7 +73,7 @@ Or, alternatively, with more man:
|
||||
| ---------------------------------------------- | ---------------------------------------------- | ---------------------------------------------- | ---------------------------------------------- |
|
||||
|  |  |  |  |
|
||||
|
||||
### Blending between prompts
|
||||
### Prompt Blending
|
||||
|
||||
- `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)`
|
||||
- The existing prompt blending using `:<weight>` will continue to be supported -
|
||||
@@ -168,6 +91,24 @@ Or, alternatively, with more man:
|
||||
See the section below on "Prompt Blending" for more information about how this
|
||||
works.
|
||||
|
||||
### Prompt Conjunction
|
||||
Join multiple clauses together to create a conjoined prompt. Each clause will be passed to CLIP separately.
|
||||
|
||||
For example, the prompt:
|
||||
|
||||
```bash
|
||||
"A mystical valley surround by towering granite cliffs, watercolor, warm"
|
||||
```
|
||||
|
||||
It can be used with `.and()`:
|
||||
```bash
|
||||
("A mystical valley", "surround by towering granite cliffs", "watercolor", "warm").and()
|
||||
```
|
||||
|
||||
Each will give you different results - try them out and see what you prefer!
|
||||
|
||||
|
||||
|
||||
### Cross-Attention Control ('prompt2prompt')
|
||||
|
||||
Sometimes an image you generate is almost right, and you just want to change one
|
||||
@@ -190,7 +131,7 @@ For example, consider the prompt `a cat.swap(dog) playing with a ball in the for
|
||||
|
||||
- For multiple word swaps, use parentheses: `a (fluffy cat).swap(barking dog) playing with a ball in the forest`.
|
||||
- To swap a comma, use quotes: `a ("fluffy, grey cat").swap("big, barking dog") playing with a ball in the forest`.
|
||||
- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
|
||||
- Supports options `t_start` and `t_end` (each 0-1) loosely corresponding to [bloc97's](https://github.com/bloc97/CrossAttentionControl) `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to
|
||||
intuitively understand. `t_start` and `t_end` are used to control on which steps cross-attention control should run. With the default values `t_start=0` and `t_end=1`, cross-attention control is active on every step of image generation. Other values can be used to turn cross-attention control off for part of the image generation process.
|
||||
- For example, for a diffusion with 10 steps and the prompt `a cat.swap(dog, t_start=0.3, t_end=1.0) playing with a ball in the forest`, the first 3 steps will be run as `a cat playing with a ball in the forest`, while the last 7 steps will run as `a dog playing with a ball in the forest`, but the pixels that represent `dog` will be locked to the pixels that would have represented `cat` if the `cat` prompt had been used instead.
|
||||
- Conversely, for `a cat.swap(dog, t_start=0, t_end=0.7) playing with a ball in the forest`, the first 7 steps will run as `a dog playing with a ball in the forest` with the pixels that represent `dog` locked to the same pixels that would have represented `cat` if the `cat` prompt was being used instead. The final 3 steps will just run `a cat playing with a ball in the forest`.
|
||||
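The step accounting in the two `.swap()` examples above can be reproduced with a small sketch. This is illustrative only; the helper below is hypothetical, and the real implementation may index steps differently:

```python
def swap_active(step: int, total_steps: int, t_start: float, t_end: float) -> bool:
    # The swapped concept is applied while the normalized step position
    # falls inside [t_start, t_end).
    return t_start <= step / total_steps < t_end

# 10 steps, t_start=0.3, t_end=1.0: steps 0-2 render "cat", steps 3-9 render "dog".
print(sum(swap_active(s, 10, 0.3, 1.0) for s in range(10)))  # 7 swapped steps
# 10 steps, t_start=0.0, t_end=0.7: steps 0-6 render "dog", steps 7-9 render "cat".
print(sum(swap_active(s, 10, 0.0, 0.7) for s in range(10)))  # 7 swapped steps
```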
@@ -201,7 +142,7 @@ Prompt2prompt `.swap()` is not compatible with xformers, which will be temporari
|
||||
The `prompt2prompt` code is based off
|
||||
[bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
|
||||
|
||||
### Escaping parantheses () and speech marks ""
|
||||
### Escaping parentheses () and speech marks ""
|
||||
|
||||
If the model you are using has parentheses () or speech marks "" as part of its
|
||||
syntax, you will need to "escape" these using a backslash, so that`(my_keyword)`
|
||||
@@ -212,23 +153,16 @@ the parentheses as part of the prompt syntax and it will get confused.
|
||||
|
||||
## **Prompt Blending**
|
||||
|
||||
You may blend together different sections of the prompt to explore the AI's
|
||||
You may blend together prompts to explore the AI's
|
||||
latent semantic space and generate interesting (and often surprising!)
|
||||
variations. The syntax is:
|
||||
|
||||
```bash
|
||||
blue sphere:0.25 red cube:0.75 hybrid
|
||||
("prompt #1", "prompt #2").blend(0.25, 0.75)
|
||||
```
|
||||
|
||||
This will tell the sampler to blend 25% of the concept of a blue sphere with 75%
|
||||
of the concept of a red cube. The blend weights can use any combination of
|
||||
integers and floating point numbers, and they do not need to add up to 1.
|
||||
Everything to the left of the `:XX` up to the previous `:XX` is used for
|
||||
merging, so the overall effect is:
|
||||
|
||||
```bash
|
||||
0.25 * "blue sphere" + 0.75 * "white duck" + hybrid
|
||||
```
|
||||
This will tell the sampler to blend 25% of the concept of prompt #1 with 75%
|
||||
of the concept of prompt #2. It is recommended to keep the sum of the weights close to 1.0, but interesting things might happen if you go outside of this range.
|
||||
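Conceptually, `.blend()` combines the conditioning produced for each prompt according to the given weights. The sketch below is a rough illustration only, not InvokeAI's or compel's actual implementation; the array shapes and the absence of any normalization are assumptions made for the example:

```python
import numpy as np

def blend(conditionings: list[np.ndarray], weights: list[float]) -> np.ndarray:
    # Weighted sum of the per-prompt conditioning tensors. Because the weights
    # are applied directly, keeping their sum close to 1.0 keeps the overall
    # magnitude of the blended conditioning roughly unchanged.
    return sum(w * c for w, c in zip(weights, conditionings))

cond_1 = np.random.rand(77, 768)  # stand-in for the conditioning of "prompt #1"
cond_2 = np.random.rand(77, 768)  # stand-in for the conditioning of "prompt #2"
blended = blend([cond_1, cond_2], [0.25, 0.75])
```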
|
||||
Because you are exploring the "mind" of the AI, the AI's way of mixing two
|
||||
concepts may not match yours, leading to surprising effects. To illustrate, here
|
||||
@@ -236,13 +170,14 @@ are three images generated using various combinations of blend weights. As
|
||||
usual, unless you fix the seed, the prompts will give you different results each
|
||||
time you run them.
|
||||
|
||||
<figure markdown>
|
||||
Let's examine how this affects image generation results:
|
||||
|
||||
### "blue sphere, red cube, hybrid"
|
||||
|
||||
</figure>
|
||||
```bash
|
||||
"blue sphere, red cube, hybrid"
|
||||
```
|
||||
|
||||
This example doesn't use melding at all and represents the default way of mixing
|
||||
This example doesn't use blending at all and represents the default way of mixing
|
||||
concepts.
|
||||
|
||||
<figure markdown>
|
||||
@@ -251,55 +186,47 @@ concepts.
|
||||
|
||||
</figure>
|
||||
|
||||
It's interesting to see how the AI expressed the concept of "cube" as the four
|
||||
quadrants of the enclosing frame. If you look closely, there is depth there, so
|
||||
the enclosing frame is actually a cube.
|
||||
It's interesting to see how the AI expressed the concept of "cube" within the sphere. If you look closely, there is depth there, so the enclosing frame is actually a cube.
|
||||
|
||||
<figure markdown>
|
||||
|
||||
### "blue sphere:0.25 red cube:0.75 hybrid"
|
||||
```bash
|
||||
("blue sphere", "red cube").blend(0.25, 0.75)
|
||||
```
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
Now that's interesting. We get neither a blue sphere nor a red cube, but a red
|
||||
sphere embedded in a brick wall, which represents a melding of concepts within
|
||||
the AI's "latent space" of semantic representations. Where is Ludwig
|
||||
Wittgenstein when you need him?
|
||||
Now that's interesting. We get an image that resembles a red cube, with a hint of blue shadows, which represents a melding of concepts within the AI's "latent space" of semantic representations.
|
||||
|
||||
<figure markdown>
|
||||
|
||||
### "blue sphere:0.75 red cube:0.25 hybrid"
|
||||
```bash
|
||||
("blue sphere", "red cube").blend(0.75, 0.25)
|
||||
```
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
Definitely more blue-spherey. The cube is gone entirely, but it's really cool
|
||||
abstract art.
|
||||
Definitely more blue-spherey.
|
||||
|
||||
<figure markdown>
|
||||
|
||||
### "blue sphere:0.5 red cube:0.5 hybrid"
|
||||
```bash
|
||||
("blue sphere", "red cube").blend(0.5, 0.5)
|
||||
```
|
||||
</figure>
|
||||
|
||||
<figure markdown>
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
Whoa...! I see blue and red, but no spheres or cubes. Is the word "hybrid"
|
||||
summoning up the concept of some sort of scifi creature? Let's find out.
|
||||
|
||||
<figure markdown>
|
||||
Whoa...! I see blue and red, and if I squint, spheres and cubes.
|
||||
|
||||
### "blue sphere:0.5 red cube:0.5"
|
||||
|
||||

|
||||
|
||||
</figure>
|
||||
|
||||
Indeed, removing the word "hybrid" produces an image that is more like what we'd
|
||||
expect.
|
||||
|
||||
## Dynamic Prompts
|
||||
|
||||
|
||||
@@ -30,10 +30,6 @@ image output.
|
||||
### * [Image-to-Image Guide](IMG2IMG.md)
|
||||
Use a seed image to build new creations in the CLI.
|
||||
|
||||
### * [Generating Variations](VARIATIONS.md)
|
||||
Have an image you like and want to generate many more like it? Variations
|
||||
are the ticket.
|
||||
|
||||
## Model Management
|
||||
|
||||
### * [Model Installation](../installation/050_INSTALLING_MODELS.md)
|
||||
|
||||
27
docs/help/diffusion.md
Normal file
@@ -0,0 +1,27 @@
|
||||
Taking the time to understand the diffusion process will help you use InvokeAI more effectively.
|
||||
|
||||
There are two main representations Stable Diffusion works with: images and latents.
|
||||
|
||||
Image space represents images in pixel form that you look at. Latent space represents compressed inputs. It’s in latent space that Stable Diffusion processes images. A VAE (Variational Auto Encoder) is responsible for compressing and encoding inputs into latent space, as well as decoding outputs back into image space.
|
||||
|
||||
To fully understand the diffusion process, we need to understand a few more terms: UNet, CLIP, and conditioning.
|
||||
|
||||
A U-Net is a model trained on a large number of latent images with known amounts of random noise added. This means that the U-Net can be given a slightly noisy image and it will predict the pattern of noise needed to subtract from the image in order to recover the original.
|
||||
|
||||
CLIP is a model that tokenizes and encodes text into conditioning. This conditioning guides the model during the denoising steps to produce a new image.
|
||||
|
||||
The U-Net and CLIP work together during the image generation process at each denoising step, with the U-Net removing noise in such a way that the result is similar to images in the U-Net’s training set, while CLIP guides the U-Net towards creating images that are most similar to the prompt.
|
||||
|
||||
|
||||
When you generate an image using text-to-image, multiple steps occur in latent space:
|
||||
1. Random noise is generated at the chosen height and width. The noise’s characteristics are dictated by the seed. This noise tensor is passed into latent space. We’ll call this noise A.
|
||||
2. Using a model’s U-Net, a noise predictor examines noise A, and the words tokenized by CLIP from your prompt (conditioning). It generates its own noise tensor to predict what the final image might look like in latent space. We’ll call this noise B.
|
||||
3. Noise B is subtracted from noise A in an attempt to create a latent image consistent with the prompt. This step is repeated for the number of sampler steps chosen.
|
||||
4. The VAE decodes the final latent image from latent space into image space.
|
||||
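Put together, steps 1-4 look roughly like the loop below. This is a simplified sketch for illustration only; the names (`clip`, `unet`, `vae`, `scheduler`) and their signatures are hypothetical and do not correspond to InvokeAI's internal APIs:

```python
import torch

def text_to_image(prompt, clip, unet, vae, scheduler, steps=30, seed=0, height=64, width=64):
    conditioning = clip(prompt)                            # tokenize and encode the prompt
    generator = torch.Generator().manual_seed(seed)
    latents = torch.randn(1, 4, height, width, generator=generator)  # step 1: "noise A"
    for t in scheduler.timesteps(steps):
        noise_pred = unet(latents, t, conditioning)        # step 2: predicted noise, "noise B"
        latents = scheduler.step(noise_pred, t, latents)   # step 3: subtract scheduled noise
    return vae.decode(latents)                             # step 4: latent space -> image space
```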
|
||||
Image-to-image is a similar process, with only step 1 being different:
|
||||
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps being performed. The process is then the same as steps 2-4 in the text-to-image process.
|
||||
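As a rough illustration of how Denoising Strength maps to the number of noise/denoise steps (a hypothetical helper, not InvokeAI's actual code):

```python
def img2img_steps(total_steps: int, denoising_strength: float) -> int:
    # Strength 0.0 -> 0 steps, so the input image passes through unchanged;
    # strength 1.0 -> all steps, so the image is fully replaced with noise
    # and then denoised from scratch.
    return round(total_steps * denoising_strength)

print(img2img_steps(30, 0.0))   # 0
print(img2img_steps(30, 0.75))  # 22 (about three quarters of the schedule)
print(img2img_steps(30, 1.0))   # 30
```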
|
||||
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).
|
||||
|
||||
A noise scheduler (e.g. DPM++ 2M Karras) schedules the subtraction of noise from the latent image across the sampler steps chosen (step 3 above). Less noise is usually subtracted at higher sampler steps.
|
||||
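To illustrate that last point, here is a small sketch of a Karras-style sigma schedule. The constants are arbitrary examples rather than InvokeAI's defaults, and the actual schedulers differ in detail:

```python
import numpy as np

def karras_sigmas(steps: int = 10, sigma_min: float = 0.03, sigma_max: float = 14.6, rho: float = 7.0):
    # Interpolate between sigma_max and sigma_min in rho-warped space
    # (Karras et al., 2022); sigma shrinks as the step index grows.
    ramp = np.linspace(0, 1, steps)
    return (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

sigmas = karras_sigmas()
print(np.round(-np.diff(sigmas), 2))  # noise removed per step: large early, small late
```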
@@ -49,9 +49,9 @@ title: Home
|
||||
[![github stars badge]][github stars link]
|
||||
[![github forks badge]][github forks link]
|
||||
|
||||
[![CI checks on main badge]][ci checks on main link]
|
||||
<!-- [![CI checks on main badge]][ci checks on main link]
|
||||
[![CI checks on dev badge]][ci checks on dev link]
|
||||
<!-- [![latest commit to dev badge]][latest commit to dev link] -->
|
||||
[![latest commit to dev badge]][latest commit to dev link] -->
|
||||
|
||||
[![github open issues badge]][github open issues link]
|
||||
[![github open prs badge]][github open prs link]
|
||||
|
||||
@@ -264,7 +264,7 @@ experimental versions later.
|
||||
you can create several levels of subfolders and drop your models into
|
||||
whichever ones you want.
|
||||
|
||||
- ***Autoimport FolderLICENSE***
|
||||
- ***LICENSE***
|
||||
|
||||
At the bottom of the screen you will see a checkbox for accepting
|
||||
the CreativeML Responsible AI Licenses. You need to accept the license
|
||||
@@ -471,7 +471,7 @@ Then type the following commands:
|
||||
|
||||
=== "NVIDIA System"
|
||||
```bash
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install xformers
|
||||
```
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ title: Installing Manually
|
||||
|
||||
</figure>
|
||||
|
||||
!!! warning "This is for advanced Users"
|
||||
!!! warning "This is for Advanced Users"
|
||||
|
||||
**python experience is mandatory**
|
||||
**Python experience is mandatory**
|
||||
|
||||
## Introduction
|
||||
|
||||
@@ -148,7 +148,7 @@ manager, please follow these steps:
|
||||
=== "CUDA (NVidia)"
|
||||
|
||||
```bash
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
@@ -192,8 +192,10 @@ manager, please follow these steps:
|
||||
your outputs.
|
||||
|
||||
```terminal
|
||||
invokeai-configure
|
||||
invokeai-configure --root .
|
||||
```
|
||||
|
||||
Don't miss the dot at the end of the command!
|
||||
|
||||
The script `invokeai-configure` will interactively guide you through the
|
||||
process of downloading and installing the weights files needed for InvokeAI.
|
||||
@@ -225,12 +227,6 @@ manager, please follow these steps:
|
||||
|
||||
!!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
|
||||
|
||||
=== "CLI"
|
||||
|
||||
```bash
|
||||
invokeai
|
||||
```
|
||||
|
||||
=== "local Webserver"
|
||||
|
||||
```bash
|
||||
@@ -243,6 +239,12 @@ manager, please follow these steps:
|
||||
invokeai --web --host 0.0.0.0
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
||||
```bash
|
||||
invokeai
|
||||
```
|
||||
|
||||
If you choose to run the web interface, point your browser at
|
||||
http://localhost:9090 in order to load the GUI.
|
||||
|
||||
@@ -310,7 +312,7 @@ installation protocol (important!)
|
||||
|
||||
=== "CUDA (NVidia)"
|
||||
```bash
|
||||
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
@@ -354,7 +356,7 @@ you can do so using this unsupported recipe:
|
||||
mkdir ~/invokeai
|
||||
conda create -n invokeai python=3.10
|
||||
conda activate invokeai
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
invokeai-configure --root ~/invokeai
|
||||
invokeai --root ~/invokeai --web
|
||||
```
|
||||
|
||||
@@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
|
||||
nvidia-cuda-toolkit package. It is out of date and will cause
|
||||
conflicts among the NVIDIA driver and binaries.**
|
||||
|
||||
Go to [CUDA Toolkit 11.7
|
||||
Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
|
||||
and use the target selection wizard to choose your operating system,
|
||||
hardware platform, and preferred installation method (e.g. "local"
|
||||
versus "network").
|
||||
Go to [CUDA Toolkit
|
||||
Downloads](https://developer.nvidia.com/cuda-downloads), and use the
|
||||
target selection wizard to choose your operating system, hardware
|
||||
platform, and preferred installation method (e.g. "local" versus
|
||||
"network").
|
||||
|
||||
This will provide you with a downloadable install file or, depending
|
||||
on your choices, a recipe for downloading and running an install shell
|
||||
@@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
the argument `--extra-index-url
|
||||
https://download.pytorch.org/whl/cu117` as described in the [Manual
|
||||
https://download.pytorch.org/whl/cu118` as described in the [Manual
|
||||
Installation Guide](020_INSTALL_MANUAL.md).
|
||||
|
||||
## :simple-amd: ROCm
|
||||
|
||||
@@ -4,9 +4,9 @@ title: Installing with Docker
|
||||
|
||||
# :fontawesome-brands-docker: Docker
|
||||
|
||||
!!! warning "For end users"
|
||||
!!! warning "For most users"
|
||||
|
||||
We highly recommend to Install InvokeAI locally using [these instructions](index.md)
|
||||
We highly recommend installing InvokeAI locally using [these instructions](INSTALLATION.md)
|
||||
|
||||
!!! tip "For developers"
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ installation. Examples:
|
||||
invokeai-model-install --list controlnet
|
||||
|
||||
# (install the model at the indicated URL)
|
||||
invokeai-model-install --add http://civitai.com/2860
|
||||
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
||||
|
||||
# (delete the named model)
|
||||
invokeai-model-install --delete sd-1/main/analog-diffusion
|
||||
@@ -170,4 +170,4 @@ elsewhere on disk and they will be autoimported. You can also create
|
||||
subfolders and organize them as you wish.
|
||||
|
||||
The location of the autoimport directories are controlled by settings
|
||||
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
||||
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
||||
|
||||
@@ -28,18 +28,21 @@ command line, then just be sure to activate it's virtual environment.
|
||||
Then run the following three commands:
|
||||
|
||||
```sh
|
||||
pip install xformers==0.0.16rc425
|
||||
pip install triton
|
||||
pip install xformers~=0.0.19
|
||||
pip install triton # WON'T WORK ON WINDOWS
|
||||
python -m xformers.info output
|
||||
```
|
||||
|
||||
The first command installs `xformers`, the second installs the
|
||||
`triton` training accelerator, and the third prints out the `xformers`
|
||||
installation status. If all goes well, you'll see a report like the
|
||||
installation status. On Windows, please omit the `triton` package,
|
||||
which is not available on that platform.
|
||||
|
||||
If all goes well, you'll see a report like the
|
||||
following:
|
||||
|
||||
```sh
|
||||
xFormers 0.0.16rc425
|
||||
xFormers 0.0.20
|
||||
memory_efficient_attention.cutlassF: available
|
||||
memory_efficient_attention.cutlassB: available
|
||||
memory_efficient_attention.flshattF: available
|
||||
@@ -48,22 +51,28 @@ memory_efficient_attention.smallkF: available
|
||||
memory_efficient_attention.smallkB: available
|
||||
memory_efficient_attention.tritonflashattF: available
|
||||
memory_efficient_attention.tritonflashattB: available
|
||||
indexing.scaled_index_addF: available
|
||||
indexing.scaled_index_addB: available
|
||||
indexing.index_select: available
|
||||
swiglu.dual_gemm_silu: available
|
||||
swiglu.gemm_fused_operand_sum: available
|
||||
swiglu.fused.p.cpp: available
|
||||
is_triton_available: True
|
||||
is_functorch_available: False
|
||||
pytorch.version: 1.13.1+cu117
|
||||
pytorch.version: 2.0.1+cu118
|
||||
pytorch.cuda: available
|
||||
gpu.compute_capability: 8.6
|
||||
gpu.name: NVIDIA RTX A2000 12GB
|
||||
gpu.compute_capability: 8.9
|
||||
gpu.name: NVIDIA GeForce RTX 4070
|
||||
build.info: available
|
||||
build.cuda_version: 1107
|
||||
build.python_version: 3.10.9
|
||||
build.torch_version: 1.13.1+cu117
|
||||
build.cuda_version: 1108
|
||||
build.python_version: 3.10.11
|
||||
build.torch_version: 2.0.1+cu118
|
||||
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
|
||||
build.env.XFORMERS_BUILD_TYPE: Release
|
||||
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
|
||||
build.env.NVCC_FLAGS: None
|
||||
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
|
||||
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
|
||||
build.nvcc_version: 11.8.89
|
||||
source.privacy: open source
|
||||
```
|
||||
|
||||
@@ -83,14 +92,14 @@ installed from source. These instructions were written for a system
|
||||
running Ubuntu 22.04, but other Linux distributions should be able to
|
||||
adapt this recipe.
|
||||
|
||||
#### 1. Install CUDA Toolkit 11.7
|
||||
#### 1. Install CUDA Toolkit 11.8
|
||||
|
||||
You will need the CUDA developer's toolkit in order to compile and
|
||||
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
|
||||
package.** It is out of date and will cause conflicts among the NVIDIA
|
||||
driver and binaries. Instead install the CUDA Toolkit package provided
|
||||
by NVIDIA itself. Go to [CUDA Toolkit 11.7
|
||||
Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
|
||||
by NVIDIA itself. Go to [CUDA Toolkit 11.8
|
||||
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
|
||||
and use the target selection wizard to choose your platform and Linux
|
||||
distribution. Select an installer type of "runfile (local)" at the
|
||||
last step.
|
||||
@@ -101,17 +110,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
|
||||
x86_64 system is:
|
||||
|
||||
```
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
|
||||
sudo sh cuda_11.7.0_515.43.04_linux.run
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
|
||||
sudo sh cuda_11.8.0_520.61.05_linux.run
|
||||
```
|
||||
|
||||
Rather than cut-and-paste this example, we recommend that you walk
|
||||
through the toolkit wizard in order to get the most up to date
|
||||
installer for your system.
|
||||
|
||||
#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support
|
||||
#### 2. Confirm/Install PyTorch 2.0.1 with CUDA 11.8 support
|
||||
|
||||
If you are using InvokeAI 2.3 or higher, these will already be
|
||||
If you are using InvokeAI 3.0.2 or higher, these will already be
|
||||
installed. If not, you can check whether you have the needed libraries
|
||||
using a quick command. Activate the invokeai virtual environment,
|
||||
either by entering the "developer's console", or manually with a
|
||||
@@ -124,7 +133,7 @@ Then run the command:
|
||||
python -c 'exec("import torch\nprint(torch.__version__)")'
|
||||
```
|
||||
|
||||
If it prints __1.13.1+cu117__ you're good. If not, you can install the
|
||||
If it prints __2.0.1+cu118__ you're good. If not, you can install the
|
||||
most up to date libraries with this command:
|
||||
|
||||
```sh
|
||||
|
||||
@@ -25,10 +25,10 @@ This method is recommended for experienced users and developers
|
||||
#### [Docker Installation](040_INSTALL_DOCKER.md)
|
||||
This method is recommended for those familiar with running Docker containers
|
||||
### Other Installation Guides
|
||||
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
|
||||
- [XFormers](installation/070_INSTALL_XFORMERS.md)
|
||||
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
|
||||
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
|
||||
- [PyPatchMatch](060_INSTALL_PATCHMATCH.md)
|
||||
- [XFormers](070_INSTALL_XFORMERS.md)
|
||||
- [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md)
|
||||
- [Installing New Models](050_INSTALLING_MODELS.md)
|
||||
|
||||
## :fontawesome-solid-computer: Hardware Requirements
|
||||
|
||||
|
||||
7
docs/javascripts/tablesort.js
Normal file
@@ -0,0 +1,7 @@
|
||||
document$.subscribe(function() {
|
||||
var tables = document.querySelectorAll("article table:not([class])")
|
||||
tables.forEach(function(table) {
|
||||
new Tablesort(table)
|
||||
})
|
||||
})
|
||||
|
||||
68
docs/nodes/NODES.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Using the Node Editor
|
||||
|
||||
The nodes editor is a blank canvas allowing for the use of individual functions and image transformations to control the image generation workflow. Nodes take in inputs on the left side of the node, and return an output on the right side of the node. A node graph is composed of multiple nodes that are connected together to create a workflow. Nodes' inputs and outputs are connected by dragging connectors from node to node. Inputs and outputs are color coded for ease of use.
|
||||
|
||||
To better understand how nodes are used, think of how an electric power bar works. It takes in one input (electricity from a wall outlet) and passes it to multiple devices through multiple outputs. Similarly, a node could have multiple inputs and outputs functioning at the same (or different) times, but all node outputs pass information onward like a power bar passes electricity. Not all outputs are compatible with all inputs, however: each node has different constraints on the information it expects to receive and output. In general, node outputs are colour-coded to match compatible inputs of other nodes.
|
||||
|
||||
|
||||
If you're not familiar with Diffusion, take a look at our [Diffusion Overview.](../help/diffusion.md) Understanding how diffusion works will enable you to more easily use the Nodes Editor and build workflows to suit your needs.
|
||||
|
||||
## Important Concepts
|
||||
|
||||
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
|
||||
|
||||
### Noise
|
||||
|
||||
An initial noise tensor is necessary for the latent diffusion process. As a result, the Denoising node requires a noise node input.
|
||||
|
||||

|
||||
|
||||
### Text Prompt Conditioning
|
||||
|
||||
Conditioning is necessary for the latent diffusion process, whether empty or not. As a result, the Denoising node requires positive and negative conditioning inputs. Conditioning is reliant on a CLIP text encoder provided by the Model Loader node.
|
||||
|
||||

|
||||
|
||||
### Image to Latents & VAE
|
||||
|
||||
The ImageToLatents node takes in a pixel image and a VAE and outputs a latents. The LatentsToImage node does the opposite, taking in a latents and a VAE and outputs a pixel image.
|
||||
|
||||

|
||||
|
||||
### Defined & Random Seeds
|
||||
|
||||
It is common to want to use both the same seed (for continuity) and random seeds (for variety). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.
|
||||
|
||||

|
||||
|
||||
### ControlNet
|
||||
|
||||
The ControlNet node outputs a Control, which can be provided as input to non-image *ToLatents nodes. Depending on the type of ControlNet desired, ControlNet nodes usually require an image processor node, such as a Canny Processor or Depth Processor, which prepares an input image for use with ControlNet.
|
||||
|
||||

|
||||
|
||||
### LoRA
|
||||
|
||||
The LoRA Loader node lets you load a LoRA and pass it as output. A LoRA provides fine-tuned adjustments to the UNet and text encoder weights that augment the base model’s image and text vocabularies.
|
||||
|
||||

|
||||
|
||||
### Scaling
|
||||
|
||||
Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. Upscaling is the process of enlarging an image and adding more detail. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results.
|
||||
|
||||

|
||||
|
||||
### Iteration + Multiple Images as Input
|
||||
|
||||
Iteration is a common concept in any kind of processing, and means repeating a process over a given set of inputs. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and use them in an image generation workflow.
|
||||
|
||||

|
||||
|
||||
### Multiple Image Generation + Random Seeds
|
||||
|
||||
Multiple image generation in the node editor is done using the RandomRange node. In this case, the 'Size' field represents the number of images to generate. As RandomRange produces a collection of integers, we need to add the Iterate node to iterate through the collection.
|
||||
|
||||
To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results.
|
||||
|
||||

|
||||
80
docs/nodes/comfyToInvoke.md
Normal file
@@ -0,0 +1,80 @@
|
||||
# ComfyUI to InvokeAI
|
||||
|
||||
If you're coming to InvokeAI from ComfyUI, welcome! You'll find things are similar but different - the good news is that you already know how things should work, and it's just a matter of wiring them up!
|
||||
|
||||
Some things to note:
|
||||
|
||||
- InvokeAI's nodes tend to be more granular than default nodes in Comfy. This means each node in Invoke will do a specific task and you might need to use multiple nodes to achieve the same result. The added granularity improves the control you have over your workflows.
|
||||
- InvokeAI's backend and ComfyUI's backend are very different, which means Comfy workflows cannot be imported into InvokeAI. However, we have created a [list of popular workflows](exampleWorkflows.md) for you to get started with Nodes in InvokeAI!
|
||||
|
||||
## Node Equivalents:
|
||||
|
||||
| Comfy UI Category | ComfyUI Node | Invoke Equivalent |
|
||||
|:---------------------------------- |:---------------------------------- | :----------------------------------|
|
||||
| Sampling |KSampler |Denoise Latents|
|
||||
| Sampling |Ksampler Advanced|Denoise Latents |
|
||||
| Loaders |Load Checkpoint | Main Model Loader _or_ SDXL Main Model Loader|
|
||||
| Loaders |Load VAE | VAE Loader |
|
||||
| Loaders |Load Lora | LoRA Loader _or_ SDXL Lora Loader|
|
||||
| Loaders |Load ControlNet Model | ControlNet|
|
||||
| Loaders |Load ControlNet Model (diff) | ControlNet|
|
||||
| Loaders |Load Style Model | Reference Only ControlNet will be coming in a future version of InvokeAI|
|
||||
| Loaders |unCLIPCheckpointLoader | N/A |
|
||||
| Loaders |GLIGENLoader | N/A |
|
||||
| Loaders |Hypernetwork Loader | N/A |
|
||||
| Loaders |Load Upscale Model | Occurs within "Upscale (RealESRGAN)"|
|
||||
|Conditioning |CLIP Text Encode (Prompt) | Compel (Prompt) or SDXL Compel (Prompt) |
|
||||
|Conditioning |CLIP Set Last Layer | CLIP Skip|
|
||||
|Conditioning |Conditioning (Average) | Use the .blend() feature of prompts |
|
||||
|Conditioning |Conditioning (Combine) | N/A |
|
||||
|Conditioning |Conditioning (Concat) | See the Prompt Tools Community Node|
|
||||
|Conditioning |Conditioning (Set Area) | N/A |
|
||||
|Conditioning |Conditioning (Set Mask) | Mask Edge |
|
||||
|Conditioning |CLIP Vision Encode | N/A |
|
||||
|Conditioning |unCLIPConditioning | N/A |
|
||||
|Conditioning |Apply ControlNet | ControlNet |
|
||||
|Conditioning |Apply ControlNet (Advanced) | ControlNet |
|
||||
|Latent |VAE Decode | Latents to Image|
|
||||
|Latent |VAE Encode | Image to Latents |
|
||||
|Latent |Empty Latent Image | Noise |
|
||||
|Latent |Upscale Latent |Resize Latents |
|
||||
|Latent |Upscale Latent By |Scale Latents |
|
||||
|Latent |Latent Composite | Blend Latents |
|
||||
|Latent |LatentCompositeMasked | N/A |
|
||||
|Image |Save Image | Image |
|
||||
|Image |Preview Image |Current |
|
||||
|Image |Load Image | Image|
|
||||
|Image |Empty Image| Blank Image |
|
||||
|Image |Invert Image | Invert Lerp Image |
|
||||
|Image |Batch Images | Link "Image" nodes into an "Image Collection" node |
|
||||
|Image |Pad Image for Outpainting | Outpainting is easily accomplished in the Unified Canvas |
|
||||
|Image |ImageCompositeMasked | Paste Image |
|
||||
|Image | Upscale Image | Resize Image |
|
||||
|Image | Upscale Image By | Upscale Image |
|
||||
|Image | Upscale Image (using Model) | Upscale Image |
|
||||
|Image | ImageBlur | Blur Image |
|
||||
|Image | ImageQuantize | N/A |
|
||||
|Image | ImageSharpen | N/A |
|
||||
|Image | Canny | Canny Processor |
|
||||
|Mask |Load Image (as Mask) | Image |
|
||||
|Mask |Convert Mask to Image | Image|
|
||||
|Mask |Convert Image to Mask | Image |
|
||||
|Mask |SolidMask | N/A |
|
||||
|Mask |InvertMask |Invert Lerp Image |
|
||||
|Mask |CropMask | Crop Image |
|
||||
|Mask |MaskComposite | Combine Mask |
|
||||
|Mask |FeatherMask | Blur Image |
|
||||
|Advanced | Load CLIP | Main Model Loader _or_ SDXL Main Model Loader|
|
||||
|Advanced | UNETLoader | Main Model Loader _or_ SDXL Main Model Loader|
|
||||
|Advanced | DualCLIPLoader | Main Model Loader _or_ SDXL Main Model Loader|
|
||||
|Advanced | Load Checkpoint | Main Model Loader _or_ SDXL Main Model Loader |
|
||||
|Advanced | ConditioningZeroOut | N/A |
|
||||
|Advanced | ConditioningSetTimestepRange | N/A |
|
||||
|Advanced | CLIPTextEncodeSDXLRefiner | Compel (Prompt) or SDXL Compel (Prompt) |
|
||||
|Advanced | CLIPTextEncodeSDXL |Compel (Prompt) or SDXL Compel (Prompt) |
|
||||
|Advanced | ModelMergeSimple | Model Merging is available in the Model Manager |
|
||||
|Advanced | ModelMergeBlocks | Model Merging is available in the Model Manager|
|
||||
|Advanced | CheckpointSave | Model saving is available in the Model Manager|
|
||||
|Advanced | CLIPMergeSimple | N/A |
|
||||
|
||||
|
||||
@@ -2,17 +2,13 @@
|
||||
|
||||
These are nodes that have been developed by the community, for the community. If you're not sure what a node is, you can learn more about nodes [here](overview.md).
|
||||
|
||||
If you'd like to submit a node for the community, please refer to the [node creation overview](./overview.md#contributing-nodes).
|
||||
If you'd like to submit a node for the community, please refer to the [node creation overview](contributingNodes.md).
|
||||
|
||||
To download a node, simply download the `.py` node file from the link and add it to the `invokeai/app/invocations/` folder in your Invoke AI install location. Along with the node, an example node graph should be provided to help you get started with the node.
|
||||
To download a node, simply download the `.py` node file from the link and add it to the `invokeai/app/invocations` folder in your Invoke AI install location. Along with the node, an example node graph should be provided to help you get started with the node.
|
||||
|
||||
To use a community node graph, download the `.json` node graph file and load it into Invoke AI via the **Load Nodes** button on the Node Editor.
|
||||
|
||||
## Disclaimer
|
||||
|
||||
The nodes linked below have been developed and contributed by members of the Invoke AI community. While we strive to ensure the quality and safety of these contributions, we do not guarantee the reliability or security of the nodes. If you have issues or concerns with any of the nodes below, please raise it on GitHub or in the Discord.
|
||||
|
||||
## List of Nodes
|
||||
## Community Nodes
|
||||
|
||||
### FaceTools
|
||||
|
||||
@@ -34,6 +30,33 @@ The nodes linked below have been developed and contributed by members of the Inv
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/ideal-size-node
|
||||
|
||||
<hr>
|
||||
|
||||
### Retroize
|
||||
|
||||
**Description:** Retroize is a collection of nodes for InvokeAI to "Retroize" images. Any image can be given a fresh coat of retro paint with these nodes, either from your gallery or from within the graph itself. It includes nodes to pixelize, quantize, palettize, and ditherize images; as well as to retrieve palettes from existing images.
|
||||
|
||||
**Node Link:** https://github.com/Ar7ific1al/invokeai-retroizeinode/
|
||||
|
||||
**Retroize Output Examples**
|
||||
|
||||

|
||||
|
||||
--------------------------------
|
||||
### GPT2RandomPromptMaker
|
||||
|
||||
**Description:** A node for InvokeAI that utilizes the GPT-2 language model to generate random prompts based on a provided seed and context.
|
||||
|
||||
**Node Link:** https://github.com/mickr777/GPT2RandomPromptMaker
|
||||
|
||||
**Output Examples**
|
||||
|
||||
Generated Prompt: An enchanted weapon will be usable by any character regardless of their alignment.
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Example Node Template
|
||||
|
||||
@@ -47,7 +70,12 @@ The nodes linked below have been developed and contributed by members of the Inv
|
||||
|
||||
{: style="height:115px;width:240px"}
|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
The nodes linked have been developed and contributed by members of the Invoke AI community. While we strive to ensure the quality and safety of these contributions, we do not guarantee the reliability or security of the nodes. If you have issues or concerns with any of the nodes below, please raise it on GitHub or in the Discord.
|
||||
|
||||
|
||||
## Help
|
||||
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).
|
||||
|
||||
|
||||
|
||||
27
docs/nodes/contributingNodes.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# Contributing Nodes
|
||||
|
||||
To learn about the specifics of creating a new node, please visit our [Node creation documentation](../contributing/INVOCATIONS.md).
|
||||
|
||||
Once you’ve created a node and confirmed that it behaves as expected locally, follow these steps:
|
||||
|
||||
- Make sure the node is contained in a new Python (.py) file
|
||||
- Submit a pull request with a link to your node in GitHub against the `nodes` branch to add the node to the [Community Nodes](communityNodes.md) list
|
||||
- Make sure you are following the template below and have provided all relevant details about the node and what it does.
|
||||
- A maintainer will review the pull request and node. If the node is aligned with the direction of the project, you might be asked for permission to include it in the core project.
|
||||
|
||||
### Community Node Template
|
||||
|
||||
```markdown
|
||||
--------------------------------
|
||||
### Super Cool Node Template
|
||||
|
||||
**Description:** This node allows you to do super cool things with InvokeAI.
|
||||
|
||||
**Node Link:** https://github.com/invoke-ai/InvokeAI/fake_node.py
|
||||
|
||||
**Example Node Graph:** https://github.com/invoke-ai/InvokeAI/fake_node_graph.json
|
||||
|
||||
**Output Examples**
|
||||
|
||||

|
||||
```
|
||||
97
docs/nodes/defaultNodes.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# List of Default Nodes
|
||||
|
||||
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
|
||||
|
||||
| Node <img width=160 align="right"> | Function |
|
||||
| :---------------------------------- | :--------------------------------------------------------------------------------------|
|
||||
|Add Integers | Adds two numbers|
|
||||
|Boolean Primitive Collection | A collection of boolean primitive values|
|
||||
|Boolean Primitive | A boolean primitive value|
|
||||
|Canny Processor | Canny edge detection for ControlNet|
|
||||
|CLIP Skip | Skip layers in clip text_encoder model.|
|
||||
|Collect | Collects values into a collection|
|
||||
|Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image.|
|
||||
|Color Primitive | A color primitive value|
|
||||
|Compel Prompt | Parse prompt using compel package to conditioning.|
|
||||
|Conditioning Primitive Collection | A collection of conditioning tensor primitive values|
|
||||
|Conditioning Primitive | A conditioning tensor primitive value|
|
||||
|Content Shuffle Processor | Applies content shuffle processing to image|
|
||||
|ControlNet | Collects ControlNet info to pass to other nodes|
|
||||
|OpenCV Inpaint | Simple inpaint using opencv.|
|
||||
|Denoise Latents | Denoises noisy latents to decodable images|
|
||||
|Divide Integers | Divides two numbers|
|
||||
|Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator|
|
||||
|Upscale (RealESRGAN) | Upscales an image using RealESRGAN.|
|
||||
|Float Primitive Collection | A collection of float primitive values|
|
||||
|Float Primitive | A float primitive value|
|
||||
|Float Range | Creates a range|
|
||||
|HED (softedge) Processor | Applies HED edge detection to image|
|
||||
|Blur Image | Blurs an image|
|
||||
|Extract Image Channel | Gets a channel from an image.|
|
||||
|Image Primitive Collection | A collection of image primitive values|
|
||||
|Convert Image Mode | Converts an image to a different mode.|
|
||||
|Crop Image | Crops an image to a specified box. The box can be outside of the image.|
|
||||
|Image Hue Adjustment | Adjusts the Hue of an image.|
|
||||
|Inverse Lerp Image | Inverse linear interpolation of all pixels of an image|
|
||||
|Image Primitive | An image primitive value|
|
||||
|Lerp Image | Linear interpolation of all pixels of an image|
|
||||
|Image Luminosity Adjustment | Adjusts the Luminosity (Value) of an image.|
|
||||
|Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`.|
|
||||
|Blur NSFW Image | Add blur to NSFW-flagged images|
|
||||
|Paste Image | Pastes an image into another image.|
|
||||
|ImageProcessor | Base class for invocations that preprocess images for ControlNet|
|
||||
|Resize Image | Resizes an image to specific dimensions|
|
||||
|Image Saturation Adjustment | Adjusts the Saturation of an image.|
|
||||
|Scale Image | Scales an image by a factor|
|
||||
|Image to Latents | Encodes an image into latents.|
|
||||
|Add Invisible Watermark | Add an invisible watermark to an image|
|
||||
|Solid Color Infill | Infills transparent areas of an image with a solid color|
|
||||
|PatchMatch Infill | Infills transparent areas of an image using the PatchMatch algorithm|
|
||||
|Tile Infill | Infills transparent areas of an image with tiles of the image|
|
||||
|Integer Primitive Collection | A collection of integer primitive values|
|
||||
|Integer Primitive | An integer primitive value|
|
||||
|Iterate | Iterates over a list of items|
|
||||
|Latents Primitive Collection | A collection of latents tensor primitive values|
|
||||
|Latents Primitive | A latents tensor primitive value|
|
||||
|Latents to Image | Generates an image from latents.|
|
||||
|Leres (Depth) Processor | Applies leres processing to image|
|
||||
|Lineart Anime Processor | Applies line art anime processing to image|
|
||||
|Lineart Processor | Applies line art processing to image|
|
||||
|LoRA Loader | Apply selected lora to unet and text_encoder.|
|
||||
|Main Model Loader | Loads a main model, outputting its submodels.|
|
||||
|Combine Mask | Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.|
|
||||
|Mask Edge | Applies an edge mask to an image|
|
||||
|Mask from Alpha | Extracts the alpha channel of an image as a mask.|
|
||||
|Mediapipe Face Processor | Applies mediapipe face processing to image|
|
||||
|Midas (Depth) Processor | Applies Midas depth processing to image|
|
||||
|MLSD Processor | Applies MLSD processing to image|
|
||||
|Multiply Integers | Multiplies two numbers|
|
||||
|Noise | Generates latent noise.|
|
||||
|Normal BAE Processor | Applies NormalBae processing to image|
|
||||
|ONNX Latents to Image | Generates an image from latents.|
|
||||
|ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers.|
|
||||
|ONNX Text to Latents | Generates latents from conditionings.|
|
||||
|ONNX Model Loader | Loads a main model, outputting its submodels.|
|
||||
|Openpose Processor | Applies Openpose processing to image|
|
||||
|PIDI Processor | Applies PIDI processing to image|
|
||||
|Prompts from File | Loads prompts from a text file|
|
||||
|Random Integer | Outputs a single random integer.|
|
||||
|Random Range | Creates a collection of random numbers|
|
||||
|Integer Range | Creates a range of numbers from start to stop with step|
|
||||
|Integer Range of Size | Creates a range from start to start + size with step|
|
||||
|Resize Latents | Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.|
|
||||
|SDXL Compel Prompt | Parse prompt using compel package to conditioning.|
|
||||
|SDXL LoRA Loader | Apply selected lora to unet and text_encoder.|
|
||||
|SDXL Main Model Loader | Loads an sdxl base model, outputting its submodels.|
|
||||
|SDXL Refiner Compel Prompt | Parse prompt using compel package to conditioning.|
|
||||
|SDXL Refiner Model Loader | Loads an sdxl refiner model, outputting its submodels.|
|
||||
|Scale Latents | Scales latents by a given factor.|
|
||||
|Segment Anything Processor | Applies segment anything processing to image|
|
||||
|Show Image | Displays a provided image, and passes it forward in the pipeline.|
|
||||
|Step Param Easing | Experimental per-step parameter easing for denoising steps|
|
||||
|String Primitive Collection | A collection of string primitive values|
|
||||
|String Primitive | A string primitive value|
|
||||
|Subtract Integers | Subtracts two numbers|
|
||||
|Tile Resample Processor | Tile resampler processor|
|
||||
|VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput|
|
||||
|Zoe (Depth) Processor | Applies Zoe depth processing to image|
|
||||
15
docs/nodes/exampleWorkflows.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Example Workflows
|
||||
|
||||
TODO: Will update once uploading workflows is available.
|
||||
|
||||
## Text2Image
|
||||
|
||||
## Image2Image
|
||||
|
||||
## ControlNet
|
||||
|
||||
## Upscaling
|
||||
|
||||
## Inpainting / Outpainting
|
||||
|
||||
## LoRAs
|
||||
@@ -1,42 +1,26 @@
|
||||
# Nodes
|
||||
|
||||
## What are Nodes?
|
||||
An Node is simply a single operation that takes in some inputs and gives
|
||||
out some outputs. We can then chain multiple nodes together to create more
|
||||
A Node is simply a single operation that takes in inputs and returns
|
||||
outputs. Multiple nodes can be linked together to create more
|
||||
complex functionality. All InvokeAI features are added through nodes.
|
||||
|
||||
This means nodes can be used to easily extend the image generation capabilities of InvokeAI, and allow you build workflows to suit your needs.
|
||||
### Anatomy of a Node
|
||||
|
||||
You can read more about nodes and the node editor [here](../features/NODES.md).
|
||||
Individual nodes are made up of the following:
|
||||
|
||||
- Inputs: Edge points on the left side of the node window where you connect outputs from other nodes.
|
||||
- Outputs: Edge points on the right side of the node window where you connect to inputs on other nodes.
|
||||
- Options: Various options which are either manually configured, or overridden by connecting an output from another node to the input.
|
||||
|
||||
|
||||
## Downloading Nodes
|
||||
To download a new node, visit our list of [Community Nodes](communityNodes.md). These are nodes that have been created by the community, for the community.
|
||||
With nodes, you can easily extend the image generation capabilities of InvokeAI and build workflows that suit your needs.
|
||||
|
||||
You can read more about nodes and the node editor [here](../nodes/NODES.md).
|
||||
|
||||
To get started with nodes, take a look at some of our examples for [common workflows](../nodes/exampleWorkflows.md)
|
||||
|
||||
## Downloading New Nodes
|
||||
To download a new node, visit our list of [Community Nodes](../nodes/communityNodes.md). These are nodes that have been created by the community, for the community.
|
||||
|
||||
|
||||
## Contributing Nodes
|
||||
|
||||
To learn about creating a new node, please visit our [Node creation documenation](../contributing/INVOCATIONS.md).
|
||||
|
||||
Once you’ve created a node and confirmed that it behaves as expected locally, follow these steps:
|
||||
* Make sure the node is contained in a new Python (.py) file
|
||||
* Submit a pull request with a link to your node in GitHub against the `nodes` branch to add the node to the [Community Nodes](Community Nodes) list
|
||||
* Make sure you are following the template below and have provided all relevant details about the node and what it does.
|
||||
* A maintainer will review the pull request and node. If the node is aligned with the direction of the project, you might be asked for permission to include it in the core project.
|
||||
|
||||
### Community Node Template
|
||||
|
||||
```markdown
|
||||
--------------------------------
|
||||
### Super Cool Node Template
|
||||
|
||||
**Description:** This node allows you to do super cool things with InvokeAI.
|
||||
|
||||
**Node Link:** https://github.com/invoke-ai/InvokeAI/fake_node.py
|
||||
|
||||
**Example Node Graph:** https://github.com/invoke-ai/InvokeAI/fake_node_graph.json
|
||||
|
||||
**Output Examples**
|
||||
|
||||

|
||||
```
|
||||
|
||||
@@ -348,7 +348,7 @@ class InvokeAiInstance:
|
||||
|
||||
introduction()
|
||||
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
# NOTE: currently the config script does its own arg parsing! this means the command-line switches
|
||||
# from the installer will also automatically propagate down to the config script.
|
||||
@@ -407,7 +407,7 @@ def get_pip_from_venv(venv_path: Path) -> str:
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
pip = "Scripts\pip.exe" if OS == "Windows" else "bin/pip"
|
||||
pip = "Scripts\\pip.exe" if OS == "Windows" else "bin/pip"
|
||||
return str(venv_path.expanduser().resolve() / pip)
|
||||
|
||||
|
||||
@@ -463,10 +463,10 @@ def get_torch_source() -> (Union[str, None], str):
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
if device == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu117"
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
if device == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu117"
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
@@ -49,7 +49,7 @@ if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
inst.install(**args.__dict__)
|
||||
except KeyboardInterrupt as exc:
|
||||
except KeyboardInterrupt:
|
||||
print("\n")
|
||||
print("Ctrl-C pressed. Aborting.")
|
||||
print("Come back soon!")
|
||||
|
||||
@@ -70,7 +70,7 @@ def confirm_install(dest: Path) -> bool:
|
||||
)
|
||||
else:
|
||||
print(f"InvokeAI will be installed in {dest}")
|
||||
dest_confirmed = not Confirm.ask(f"Would you like to pick a different location?", default=False)
|
||||
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
@@ -90,7 +90,7 @@ def dest_path(dest=None) -> Path:
|
||||
dest = Path(dest).expanduser().resolve()
|
||||
else:
|
||||
dest = Path.cwd().expanduser().resolve()
|
||||
prev_dest = dest.expanduser().resolve()
|
||||
prev_dest = init_path = dest
|
||||
|
||||
dest_confirmed = confirm_install(dest)
|
||||
|
||||
@@ -109,9 +109,9 @@ def dest_path(dest=None) -> Path:
|
||||
)
|
||||
|
||||
console.line()
|
||||
print(f"[orange3]Please select the destination directory for the installation:[/] \[{browse_start}]: ")
|
||||
console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
|
||||
selected = prompt(
|
||||
f">>> ",
|
||||
">>> ",
|
||||
complete_in_thread=True,
|
||||
completer=path_completer,
|
||||
default=str(browse_start) + os.sep,
|
||||
@@ -134,14 +134,14 @@ def dest_path(dest=None) -> Path:
|
||||
try:
|
||||
dest.mkdir(exist_ok=True, parents=True)
|
||||
return dest
|
||||
except PermissionError as exc:
|
||||
print(
|
||||
except PermissionError:
|
||||
console.print(
|
||||
f"Failed to create directory {dest} due to insufficient permissions",
|
||||
style=Style(color="red"),
|
||||
highlight=True,
|
||||
)
|
||||
except OSError as exc:
|
||||
console.print_exception(exc)
|
||||
except OSError:
|
||||
console.print_exception()
|
||||
|
||||
if Confirm.ask("Would you like to try again?"):
|
||||
dest_path(init_path)
|
||||
|
||||
@@ -8,16 +8,13 @@ Preparations:
to work. Instructions are given here:
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/

NOTE: At this time we do not recommend Python 3.11. We recommend
Version 3.10.9, which has been extensively tested with InvokeAI.

Before you start the installer, please open up your system's command
line window (Terminal or Command) and type the commands:

python --version

If all is well, it will print "Python 3.X.X", where the version number
is at least 3.9.1, and less than 3.11.
is at least 3.9.*, and not higher than 3.11.*.

If this works, check the version of the Python package manager, pip:
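A small, self-contained check of the version window described above; the exact bounds are taken from the revised text and should be treated as an assumption:

import sys

# Accept Python 3.9.x through 3.11.x, per the updated installer notes.
supported = (3, 9) <= sys.version_info[:2] <= (3, 11)
print(f"Python {sys.version.split()[0]} is {'supported' if supported else 'NOT supported'}")
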
@@ -1,8 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
@@ -30,6 +28,7 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@@ -45,7 +44,7 @@ def check_internet() -> bool:
try:
urllib.request.urlopen(host, timeout=1)
return True
except:
except Exception:
return False
@@ -55,7 +54,7 @@ logger = InvokeAILogger.getLogger()
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Optional[Invoker] = None
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||
@@ -68,8 +67,9 @@ class ApiDependencies:
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
@@ -128,6 +128,7 @@ class ApiDependencies:
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ async def get_version() -> AppVersion:

@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
infill_methods = ["tile"]
infill_methods = ["tile", "lama"]
if PatchMatch.patchmatch_available():
infill_methods.append("patchmatch")
@@ -1,24 +1,30 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
class AddImagesToBoardResult(BaseModel):
|
||||
board_id: str = Field(description="The id of the board the images were added to")
|
||||
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(BaseModel):
|
||||
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="create_board_image",
|
||||
operation_id="add_image_to_board",
|
||||
responses={
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board_image(
|
||||
async def add_image_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
@@ -28,27 +34,79 @@ async def create_board_image(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||
|
||||
|
||||
@board_images_router.delete(
|
||||
"/",
|
||||
operation_id="remove_board_image",
|
||||
operation_id="remove_image_from_board",
|
||||
responses={
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def remove_board_image(
|
||||
board_id: str = Body(description="The id of the board"),
|
||||
image_name: str = Body(description="The name of the image to remove"),
|
||||
async def remove_image_from_board(
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
):
|
||||
"""Deletes a board_image"""
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/batch",
|
||||
operation_id="add_images_to_board",
|
||||
responses={
|
||||
201: {"description": "Images were added to board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_images_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_names: list[str] = Body(description="The names of the images to add", embed=True),
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
try:
|
||||
added_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
added_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/batch/delete",
|
||||
operation_id="remove_images_from_board",
|
||||
responses={
|
||||
201: {"description": "Images were removed from board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_images_from_board(
|
||||
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
try:
|
||||
removed_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
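An illustrative client call for the two batch endpoints added above, using the requests package; the host, port, and /api prefix are assumptions, and the board and image names are placeholders:

import requests

BASE = "http://127.0.0.1:9090/api/v1/board_images"  # assumed base URL

# add_images_to_board: board_id and image_names travel in the JSON body
r = requests.post(f"{BASE}/batch", json={"board_id": "my-board", "image_names": ["a.png", "b.png"]})
print(r.json())  # -> {"board_id": "my-board", "added_image_names": [...]}

# remove_images_from_board: only image_names is required
r = requests.post(f"{BASE}/batch/delete", json={"image_names": ["a.png", "b.png"]})
print(r.json())  # -> {"removed_image_names": [...]}
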
@@ -37,7 +37,7 @@ async def create_board(
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ async def get_board(
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ async def update_board(
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ async def delete_board(
|
||||
deleted_board_images=deleted_board_images,
|
||||
deleted_images=[],
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete board")
|
||||
|
||||
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
import io
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/",
|
||||
"/upload",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
@@ -55,7 +55,7 @@ async def upload_image(
|
||||
if crop_visible:
|
||||
bbox = pil_image.getbbox()
|
||||
pil_image = pil_image.crop(bbox)
|
||||
except:
|
||||
except Exception:
|
||||
# Error opening the image
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
@@ -73,11 +73,11 @@ async def upload_image(
|
||||
response.headers["Location"] = image_dto.image_url
|
||||
|
||||
return image_dto
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
@images_router.delete("/{image_name}", operation_id="delete_image")
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
@@ -85,7 +85,7 @@ async def delete_image(
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
@@ -97,13 +97,13 @@ async def clear_intermediates() -> int:
|
||||
try:
|
||||
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
|
||||
return count_deleted
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
"/{image_name}",
|
||||
"/i/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
@@ -115,12 +115,12 @@ async def update_image(
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}",
|
||||
"/i/{image_name}",
|
||||
operation_id="get_image_dto",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
@@ -131,12 +131,12 @@ async def get_image_dto(
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/metadata",
|
||||
"/i/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageMetadata,
|
||||
)
|
||||
@@ -147,12 +147,13 @@ async def get_image_metadata(
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/full",
|
||||
@images_router.api_route(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -182,12 +183,12 @@ async def get_image_full(
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/thumbnail",
|
||||
"/i/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -211,12 +212,12 @@ async def get_image_thumbnail(
|
||||
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/urls",
|
||||
"/i/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
@@ -233,7 +234,7 @@ async def get_image_urls(
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@@ -265,3 +266,62 @@ async def list_image_dtos(
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
|
||||
class DeleteImagesFromListResult(BaseModel):
|
||||
deleted_images: list[str]
|
||||
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||
async def delete_images_from_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesFromListResult:
|
||||
try:
|
||||
deleted_images: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
class ImagesUpdatedFromListResult(BaseModel):
|
||||
updated_image_names: list[str] = Field(description="The image names that were updated")
|
||||
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
async def star_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||
updated_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
async def unstar_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||
updated_image_names.append(image_name)
|
||||
except Exception:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
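Likewise, an illustrative call to the star/unstar list endpoints defined above; the base URL is an assumption and the image names are placeholders:

import requests

BASE = "http://127.0.0.1:9090/api/v1/images"  # assumed base URL

names = ["a.png", "b.png"]
starred = requests.post(f"{BASE}/star", json={"image_names": names}).json()
print(starred["updated_image_names"])

unstarred = requests.post(f"{BASE}/unstar", json={"image_names": names}).json()
print(unstarred["updated_image_names"])
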
@@ -104,8 +104,12 @@ async def update_model(
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")

# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}

ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
)

model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@@ -1,12 +1,13 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ...invocations import *
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import asyncio
|
||||
import sys
|
||||
from inspect import signature
|
||||
|
||||
import logging
|
||||
@@ -17,35 +16,30 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
# This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
# we call this early so that the message appears before
|
||||
# other invokeai initialization messages
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
import mimetypes
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||
|
||||
import torch
|
||||
import invokeai.backend.util.hotfixes
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
@@ -128,12 +122,18 @@ def custom_openapi():
|
||||
|
||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
||||
output_schema["class"] = "output"
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__
|
||||
@@ -141,8 +141,8 @@ def custom_openapi():
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
|
||||
@@ -225,13 +225,16 @@ def invoke_api():
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
for logname in ["uvicorn.access", "uvicorn"]:
|
||||
l = logging.getLogger(logname)
|
||||
l.handlers.clear()
|
||||
log = logging.getLogger(logname)
|
||||
log.handlers.clear()
|
||||
for ch in logger.handlers:
|
||||
l.addHandler(ch)
|
||||
log.addHandler(ch)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_api()
|
||||
|
||||
@@ -145,10 +145,10 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
try:
|
||||
readline.set_auto_history(True)
|
||||
except:
|
||||
except AttributeError:
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
pass
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(" ")
|
||||
|
||||
@@ -13,16 +13,8 @@ from pydantic.fields import Field
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
# we call this early so that the message appears before other invokeai initialization messages
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
@@ -37,6 +29,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
@@ -61,10 +54,15 @@ from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
import torch
|
||||
import invokeai.backend.util.hotfixes
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
@@ -311,6 +309,7 @@ def invoke_cli():
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
@@ -480,4 +479,7 @@ def invoke_cli():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_cli()
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_cli()
|
||||
|
||||
@@ -3,15 +3,376 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseConfig, BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import Undefined
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||
skipped_layers = "Number of layers to skip in text encoder"
|
||||
seed = "Seed for random number generation"
|
||||
steps = "Number of steps to run"
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
core_metadata = "Optional core metadata to be written to image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
precision = "Precision to use"
|
||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||
detect_res = "Pixel resolution for detection"
|
||||
image_res = "Pixel resolution for output image"
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
blend_alpha = (
|
||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
|
||||
|
||||
class Input(str, Enum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class UIType(str, Enum):
|
||||
"""
|
||||
Type hints for the UI.
|
||||
If a field should be provided a data type that does not exactly match the python type of the field, \
|
||||
use this to provide the type that should be used instead. See the node development docs for detail \
|
||||
on adding a new field type, which involves client-side changes.
|
||||
"""
|
||||
|
||||
# region Primitives
|
||||
Integer = "integer"
|
||||
Float = "float"
|
||||
Boolean = "boolean"
|
||||
String = "string"
|
||||
Array = "array"
|
||||
Image = "ImageField"
|
||||
Latents = "LatentsField"
|
||||
Conditioning = "ConditioningField"
|
||||
Control = "ControlField"
|
||||
Color = "ColorField"
|
||||
ImageCollection = "ImageCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
StringCollection = "StringCollection"
|
||||
BooleanCollection = "BooleanCollection"
|
||||
# endregion
|
||||
|
||||
# region Models
|
||||
MainModel = "MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VaeModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
UNet = "UNetField"
|
||||
Vae = "VaeField"
|
||||
CLIP = "ClipField"
|
||||
# endregion
|
||||
|
||||
# region Iterate/Collect
|
||||
Collection = "Collection"
|
||||
CollectionItem = "CollectionItem"
|
||||
# endregion
|
||||
|
||||
# region Misc
|
||||
FilePath = "FilePath"
|
||||
Enum = "enum"
|
||||
Scheduler = "Scheduler"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are \
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class _InputField(BaseModel):
|
||||
"""
|
||||
*DO NOT USE*
|
||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||
purpose in the backend.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_component: Optional[UIComponent]
|
||||
ui_order: Optional[int]
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
"""
|
||||
*DO NOT USE*
|
||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||
purpose in the backend.
|
||||
"""
|
||||
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
|
||||
def InputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
**kwargs,
|
||||
)
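A hypothetical node definition showing how the parameters documented above are meant to be used; the class, its type literal, and its fields are invented for illustration, and the invoke() method is omitted:

from typing import Literal

@title("Example Add")               # decorator defined later in this file
@tags("math", "example")            # decorator defined later in this file
class ExampleAddInvocation(BaseInvocation):
    """Adds two numbers (illustrative only)."""

    type: Literal["example_add"] = "example_add"

    # May be supplied directly or via a graph connection (Input.Any is the default)
    a: int = InputField(default=0, description=FieldDescriptions.num_1)
    # Must be supplied by a connection from another node
    b: int = InputField(default=0, description=FieldDescriptions.num_2, input=Input.Connection)
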
def OutputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
This is used internally by the @tags and @title decorator logic. You probably want to use those
|
||||
decorators, though you may add this class to a node definition to specify the title and tags.
|
||||
"""
|
||||
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
|
||||
title: Optional[str] = Field(default=None, description="The display name of the node")
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
@@ -25,7 +386,7 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""Base class for all invocation outputs"""
|
||||
|
||||
# All outputs must include a type name like this:
|
||||
# type: Literal['your_output_name']
|
||||
# type: Literal['your_output_name'] # noqa f821
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses_tuple(cls):
|
||||
@@ -38,6 +399,27 @@ class BaseInvocationOutput(BaseModel):
|
||||
toprocess.extend(next_subclasses)
|
||||
return tuple(subclasses)
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
|
||||
class RequiredConnectionException(Exception):
|
||||
"""Raised when an field which requires a connection did not receive a value."""
|
||||
|
||||
def __init__(self, node_id: str, field_name: str):
|
||||
super().__init__(f"Node {node_id} missing connections for field {field_name}")
|
||||
|
||||
|
||||
class MissingInputException(Exception):
|
||||
"""Raised when an field which requires some input, but did not receive a value."""
|
||||
|
||||
def __init__(self, node_id: str, field_name: str):
|
||||
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
|
||||
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
"""A node to process inputs and produce outputs.
|
||||
@@ -45,7 +427,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
# All invocations must include a type name like this:
|
||||
# type: Literal['your_output_name']
|
||||
# type: Literal['your_output_name'] # noqa f821
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
@@ -76,70 +458,84 @@ class BaseInvocation(ABC, BaseModel):
|
||||
def get_output_type(cls):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
# fmt: off
|
||||
def __init__(self, **data):
|
||||
# nodes may have required fields, that can accept input from connections
|
||||
# on instantiation of the model, we need to exclude these from validation
|
||||
restore = dict()
|
||||
try:
|
||||
field_names = list(self.__fields__.keys())
|
||||
for field_name in field_names:
|
||||
# if the field is required and may get its value from a connection, exclude it from validation
|
||||
field = self.__fields__[field_name]
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if _input in [Input.Connection, Input.Any] and field.required:
|
||||
if field_name not in data:
|
||||
restore[field_name] = self.__fields__.pop(field_name)
|
||||
# instantiate the node, which will validate the data
|
||||
super().__init__(**data)
|
||||
finally:
|
||||
# restore the removed fields
|
||||
for field_name, field in restore.items():
|
||||
self.__fields__[field_name] = field
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
for field_name, field in self.__fields__.items():
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if field.required and not hasattr(self, field_name):
|
||||
if _input == Input.Connection:
|
||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||
elif _input == Input.Any:
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
return self.invoke(context)
|
||||
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||
# fmt: on
|
||||
is_intermediate: bool = InputField(
|
||||
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
||||
)
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
|
||||
# TODO: figure out a better way to provide these hints
|
||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
||||
class UIConfig(TypedDict, total=False):
|
||||
type_hints: Dict[
|
||||
str,
|
||||
Literal[
|
||||
"integer",
|
||||
"float",
|
||||
"boolean",
|
||||
"string",
|
||||
"enum",
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
"control",
|
||||
"image_collection",
|
||||
"vae_model",
|
||||
"lora_model",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
title: str
|
||||
T = TypeVar("T", bound=BaseInvocation)
|
||||
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
def title(title: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
|
||||
|
||||
def wrapper(cls: Type[T]) -> Type[T]:
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig.title = title
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class InvocationConfig(BaseConfig):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
def wrapper(cls: Type[T]) -> Type[T]:
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig.tags = list(tags)
|
||||
return cls
|
||||
|
||||
`tags`
|
||||
- A list of strings, used to categorise invocations.
|
||||
|
||||
`type_hints`
|
||||
- A dict of field types which override the types in the invocation definition.
|
||||
- Each key should be the name of one of the invocation's fields.
|
||||
- Each value should be one of the valid types:
|
||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
||||
|
||||
```python
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"initial_image": "image",
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
schema_extra: CustomisedSchemaExtra
|
||||
return wrapper
|
||||
|
||||
@@ -3,58 +3,25 @@
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
from pydantic import validator
|
||||
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of integers"""
|
||||
|
||||
type: Literal["int_collection"] = "int_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of floats"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = Field(default=[], description="The float collection")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of images"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = Field(default=[], description="The output images")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "collection"]}
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
@title("Integer Range")
|
||||
@tags("collection", "integer", "range")
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
type: Literal["range"] = "range"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
stop: int = Field(default=10, description="The stop of the range")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
|
||||
}
|
||||
start: int = InputField(default=0, description="The start of the range")
|
||||
stop: int = InputField(default=10, description="The stop of the range")
|
||||
step: int = InputField(default=1, description="The step of the range")
|
||||
|
||||
@validator("stop")
|
||||
def stop_gt_start(cls, v, values):
|
||||
@@ -62,76 +29,44 @@ class RangeInvocation(BaseInvocation):
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
|
||||
|
||||
@title("Integer Range of Size")
|
||||
@tags("range", "integer", "size", "collection")
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
|
||||
type: Literal["range_of_size"] = "range_of_size"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
size: int = Field(default=1, description="The number of values")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
start: int = InputField(default=0, description="The start of the range")
|
||||
size: int = InputField(default=1, description="The number of values")
|
||||
step: int = InputField(default=1, description="The step of the range")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||
|
||||
|
||||
@title("Random Range")
|
||||
@tags("range", "integer", "random", "collection")
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
type: Literal["random_range"] = "random_range"
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: int = Field(
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
size: int = InputField(default=1, description="The number of values to generate")
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||
|
||||
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""Load a collection of images and provide it as output."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
images: list[ImageField] = Field(
|
||||
default=[], description="The image collection to load"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.images)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"title": "Image Collection",
|
||||
"images": "image_collection",
|
||||
}
|
||||
},
|
||||
}
|
||||
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||
|
||||
@@ -1,56 +1,40 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField

from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
from dataclasses import dataclass
from typing import List, Literal, Union

import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput

from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
    BasicConditioningInfo,
    SDXLConditioningInfo,
)

from ...backend.model_management.models import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    UIComponent,
    tags,
    title,
)
from .model import ClipField
from dataclasses import dataclass


class ConditioningField(BaseModel):
    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

    class Config:
        schema_extra = {"required": ["conditioning_name"]}


@dataclass
class BasicConditioningInfo:
    # type: Literal["basic_conditioning"] = "basic_conditioning"
    embeds: torch.Tensor
    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
    # weight: float
    # mode: ConditioningAlgo


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    # type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor


ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]


@dataclass
class ConditioningFieldData:
    conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
    conditionings: List[BasicConditioningInfo]
    # unconditioned: Optional[torch.Tensor]

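To make the conditioning plumbing concrete, a minimal sketch (not part of the diff) of how a text-encoder result would be packaged with the dataclasses defined above before being saved under a name; the tensor shape is illustrative only:

import torch

# assumes the BasicConditioningInfo / ConditioningFieldData definitions above
embeds = torch.zeros(1, 77, 768)  # illustrative CLIP embedding shape
info = BasicConditioningInfo(embeds=embeds, extra_conditioning=None)
data = ConditioningFieldData(conditionings=[info])
# downstream nodes look this up by the name carried in ConditioningField(conditioning_name=...)
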
@@ -60,32 +44,26 @@ class ConditioningFieldData:
#    PerpNeg = "perp_neg"


class CompelOutput(BaseInvocationOutput):
    """Compel parser output"""

    # fmt: off
    type: Literal["compel_output"] = "compel_output"

    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    # fmt: on


@title("Compel Prompt")
@tags("prompt", "compel")
class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    prompt: str = Field(default="", description="Prompt")
    clip: ClipField = Field(None, description="Clip to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
        }
    prompt: str = InputField(
        default="",
        description=FieldDescriptions.compel_prompt,
        ui_component=UIComponent.Textarea,
    )
    clip: ClipField = InputField(
        title="CLIP",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> CompelOutput:
    def invoke(self, context: InvocationContext) -> ConditioningOutput:
        tokenizer_info = context.services.model_manager.get_model(
            **self.clip.tokenizer.dict(),
            context=context,
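The hunk above captures the core of this migration: UI metadata that used to live in Config.schema_extra["ui"] now comes from the @title/@tags decorators, and Field gives way to InputField with the FieldDescriptions, UIComponent and Input helpers. A hypothetical minimal node in the new style (illustration only, relying on the baseinvocation imports shown earlier; this node is not defined anywhere in this branch):

@title("Prompt Passthrough (example)")
@tags("prompt", "example")
class PromptPassthroughExampleInvocation(BaseInvocation):
    """Hypothetical example: title/tags via decorators, inputs via InputField."""

    type: Literal["prompt_passthrough_example"] = "prompt_passthrough_example"

    prompt: str = InputField(default="", description="A prompt", ui_component=UIComponent.Textarea)
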
@@ -109,12 +87,15 @@ class CompelInvocation(BaseInvocation):
            name = trigger[1:-1]
            try:
                ti_list.append(
                    context.services.model_manager.get_model(
                        model_name=name,
                        base_model=self.clip.text_encoder.base_model,
                        model_type=ModelType.TextualInversion,
                        context=context,
                    ).context.model
                    (
                        name,
                        context.services.model_manager.get_model(
                            model_name=name,
                            base_model=self.clip.text_encoder.base_model,
                            model_type=ModelType.TextualInversion,
                            context=context,
                        ).context.model,
                    )
                )
            except ModelNotFoundException:
                # print(e)
@@ -165,7 +146,7 @@ class CompelInvocation(BaseInvocation):
        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
        context.services.latents.save(conditioning_name, conditioning_data)

        return CompelOutput(
        return ConditioningOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
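The change above is to what ti_list holds: the old code appended the bare textual-inversion model, the new code appends a (name, model) pair. A small self-contained sketch (not part of the diff) of the trigger extraction that feeds this loop, using the same regex that appears later in the file:

import re

prompt = "a photo of <easynegative> and <my-style> at dusk"
triggers = re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)
names = [t[1:-1] for t in triggers]  # strip the angle brackets, as in `name = trigger[1:-1]`
print(names)  # ['easynegative', 'my-style']
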
@@ -173,7 +154,15 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_field: ClipField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
@@ -183,79 +172,21 @@ class SDXLPromptInvocationBase:
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, clip_field.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
cpu_text_encoder = text_encoder_info.context.model
|
||||
c = torch.zeros(
|
||||
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
|
||||
dtype=text_encoder_info.context.cache.precision,
|
||||
)
|
||||
if get_pooled:
|
||||
c_pooled = prompt_embeds[0]
|
||||
c_pooled = torch.zeros(
|
||||
(1, cpu_text_encoder.config.hidden_size),
|
||||
dtype=c.dtype,
|
||||
)
|
||||
else:
|
||||
c_pooled = None
|
||||
c = prompt_embeds.hidden_states[-2]
|
||||
|
||||
del tokenizer
|
||||
del text_encoder
|
||||
del tokenizer_info
|
||||
del text_encoder_info
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
if c_pooled is not None:
|
||||
c_pooled = c_pooled.detach().to("cpu")
|
||||
|
||||
return c, c_pooled, None
|
||||
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
return c, c_pooled, None
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
@@ -271,12 +202,15 @@ class SDXLPromptInvocationBase:
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
@@ -284,8 +218,8 @@ class SDXLPromptInvocationBase:
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
with ModelPatcher.apply_lora(
|
||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
@@ -299,7 +233,7 @@ class SDXLPromptInvocationBase:
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=True,
|
||||
requires_pooled=get_pooled,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
@@ -333,35 +267,37 @@ class SDXLPromptInvocationBase:
|
||||
return c, c_pooled, ec
|
||||
|
||||
|
||||
@title("SDXL Compel Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
style: str = Field(default="", description="Style prompt")
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
target_width: int = Field(1024, description="")
|
||||
target_height: int = Field(1024, description="")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@@ -383,39 +319,34 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Compel Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
aesthetic_score: float = Field(6.0, description="")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
style: str = InputField(
|
||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||
) # TODO: ?
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
@@ -436,117 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Pass unmodified prompt to conditioning without compel processing."""
|
||||
|
||||
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
style: str = Field(default="", description="Style prompt")
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
target_width: int = Field(1024, description="")
|
||||
target_height: int = Field(1024, description="")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=torch.cat([c1, c2], dim=-1),
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec1,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
aesthetic_score: float = Field(6.0, description="")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Raw)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@@ -557,21 +378,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@title("CLIP Skip")
|
||||
@tags("clipskip", "clip", "skip")
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
type: Literal["clip_skip"] = "clip_skip"
|
||||
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
||||
}
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
|
||||
@@ -26,79 +26,31 @@ from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from ..models.image import ImageOutput, PILInvocationConfig
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
||||
##############################################
|
||||
"lllyasviel/sd-controlnet-canny",
|
||||
"lllyasviel/sd-controlnet-depth",
|
||||
"lllyasviel/sd-controlnet-hed",
|
||||
"lllyasviel/sd-controlnet-seg",
|
||||
"lllyasviel/sd-controlnet-openpose",
|
||||
"lllyasviel/sd-controlnet-scribble",
|
||||
"lllyasviel/sd-controlnet-normal",
|
||||
"lllyasviel/sd-controlnet-mlsd",
|
||||
#############################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||
#############################################
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
"lllyasviel/control_v11p_sd15_openpose",
|
||||
"lllyasviel/control_v11p_sd15_seg",
|
||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
"lllyasviel/control_v11p_sd15_normalbae",
|
||||
"lllyasviel/control_v11p_sd15_scribble",
|
||||
"lllyasviel/control_v11p_sd15_mlsd",
|
||||
"lllyasviel/control_v11p_sd15_softedge",
|
||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
"lllyasviel/control_v11p_sd15_lineart",
|
||||
"lllyasviel/control_v11p_sd15_inpaint",
|
||||
# "lllyasviel/control_v11u_sd15_tile",
|
||||
# problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
|
||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
||||
"lllyasviel/control_v11e_sd15_shuffle",
|
||||
"lllyasviel/control_v11e_sd15_ip2p",
|
||||
"lllyasviel/control_v11f1e_sd15_tile",
|
||||
#################################################
|
||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
|
||||
##################################################
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
||||
"thibaud/controlnet-sd21-hed-diffusers",
|
||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
||||
"thibaud/controlnet-sd21-color-diffusers",
|
||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||
##############################################
|
||||
# ControlNetMediaPipeface, ControlNet v1.1
|
||||
##############################################
|
||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
||||
]
|
||||
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
|
||||
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[
|
||||
tuple(
|
||||
[
|
||||
"just_resize",
|
||||
"crop_resize",
|
||||
"fill_resize",
|
||||
"just_resize_simple",
|
||||
]
|
||||
)
|
||||
"just_resize",
|
||||
"crop_resize",
|
||||
"fill_resize",
|
||||
"just_resize_simple",
|
||||
]
|
||||
|
||||
|
||||
@@ -110,9 +62,8 @@ class ControlNetModelField(BaseModel):
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
@@ -135,60 +86,39 @@ class ControlField(BaseModel):
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
"control_model": "controlnet_model",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@title("ControlNet")
|
||||
@tags("controlnet")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "ControlNet",
|
||||
"tags": ["controlnet", "latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(
|
||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
@@ -204,19 +134,13 @@ class ControlNetInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_processor"] = "image_processor"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to process")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image):
|
||||
# superclass just passes through image without processing
|
||||
@@ -255,20 +179,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Canny Processor")
|
||||
@tags("controlnet", "canny")
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
||||
}
|
||||
# Input
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
@@ -276,23 +200,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("HED (softedge) Processor")
|
||||
@tags("controlnet", "hed", "softedge")
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -307,21 +227,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Lineart Processor")
|
||||
@tags("controlnet", "lineart")
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -331,23 +247,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Lineart Anime Processor")
|
||||
@tags("controlnet", "lineart", "anime")
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lineart Anime Processor",
|
||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -359,21 +268,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
|
||||
return processed_image
|
||||
|
||||
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Openpose Processor")
|
||||
@tags("controlnet", "openpose", "pose")
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
# Inputs
|
||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -386,22 +291,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
||||
return processed_image
|
||||
|
||||
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Midas (Depth) Processor")
|
||||
@tags("controlnet", "midas", "depth")
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
# Inputs
|
||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -415,20 +316,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
|
||||
return processed_image
|
||||
|
||||
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Normal BAE Processor")
|
||||
@tags("controlnet", "normal", "bae")
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -438,22 +335,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
|
||||
return processed_image
|
||||
|
||||
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("MLSD Processor")
|
||||
@tags("controlnet", "mlsd")
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image):
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -467,22 +360,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
||||
return processed_image
|
||||
|
||||
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("PIDI Processor")
|
||||
@tags("controlnet", "pidi")
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -496,26 +385,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
||||
return processed_image
|
||||
|
||||
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Content Shuffle Processor")
|
||||
@tags("controlnet", "contentshuffle")
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Content Shuffle Processor",
|
||||
"tags": ["controlnet", "contentshuffle", "image", "processor"],
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
@@ -531,17 +413,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Zoe (Depth) Processor")
|
||||
@tags("controlnet", "zoe", "depth")
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -549,20 +426,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
||||
return processed_image
|
||||
|
||||
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Mediapipe Face Processor")
|
||||
@tags("controlnet", "mediapipe", "face")
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
# Inputs
|
||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
|
||||
def run_processor(self, image):
|
||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||
@@ -574,23 +447,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
||||
return processed_image
|
||||
|
||||
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Leres (Depth) Processor")
|
||||
@tags("controlnet", "leres", "depth")
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||
# Inputs
|
||||
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = Field(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
|
||||
}
|
||||
# Inputs
|
||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
@@ -605,21 +474,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
||||
return processed_image
|
||||
|
||||
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
# fmt: off
|
||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||
# Inputs
|
||||
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
# fmt: on
|
||||
@title("Tile Resample Processor")
|
||||
@tags("controlnet", "tile")
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Tile Resample Processor",
|
||||
"tags": ["controlnet", "tile", "resample", "image", "processor"],
|
||||
},
|
||||
}
|
||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||
|
||||
# Inputs
|
||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(
|
||||
@@ -648,20 +512,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
||||
return processed_image
|
||||
|
||||
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Segment Anything Processor")
|
||||
@tags("controlnet", "segmentanything")
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Segment Anything Processor",
|
||||
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
|
||||
@@ -5,40 +5,22 @@ from typing import Literal
|
||||
import cv2 as cv
|
||||
import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
class CvInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["cv", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
@title("OpenCV Inpaint")
|
||||
@tags("opencv", "inpaint")
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -1,251 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from functools import partial
from typing import Literal, Optional, get_args

import torch
from pydantic import Field

from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods

from ...backend.generator import Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput

from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, VaeField
from .compel import ConditioningField
from contextlib import contextmanager, ExitStack, ContextDecorator

SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"


from .latent import get_scheduler


class OldModelContext(ContextDecorator):
    model: StableDiffusionGeneratorPipeline

    def __init__(self, model):
        self.model = model

    def __enter__(self):
        return self.model

    def __exit__(self, *exc):
        return False


class OldModelInfo:
    name: str
    hash: str
    context: OldModelContext

    def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
        self.name = name
        self.hash = hash
        self.context = OldModelContext(
            model=model,
        )


class InpaintInvocation(BaseInvocation):
    """Generates an image using inpaint."""

    type: Literal["inpaint"] = "inpaint"

    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    seed: int = Field(
        ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
    )
    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(
        default=512,
        multiple_of=8,
        gt=0,
        description="The width of the resulting image",
    )
    height: int = Field(
        default=512,
        multiple_of=8,
        gt=0,
        description="The height of the resulting image",
    )
    cfg_scale: float = Field(
        default=7.5,
        ge=1,
        description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
    )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
    unet: UNetField = Field(default=None, description="UNet model")
    vae: VaeField = Field(default=None, description="Vae model")

    # Inputs
    image: Optional[ImageField] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
    fit: bool = Field(
        default=True,
        description="Whether or not the result should be fit to the aspect ratio of the input image",
    )

    # Inputs
    mask: Optional[ImageField] = Field(description="The mask")
    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
    seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
    seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
    seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
    tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
    infill_method: INFILL_METHODS = Field(
        default=DEFAULT_INFILL_METHOD,
        description="The method used to infill empty regions (px)",
    )
    inpaint_width: Optional[int] = Field(
        default=None,
        multiple_of=8,
        gt=0,
        description="The width of the inpaint region (px)",
    )
    inpaint_height: Optional[int] = Field(
        default=None,
        multiple_of=8,
        gt=0,
        description="The height of the inpaint region (px)",
    )
    inpaint_fill: Optional[ColorField] = Field(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The solid infill method color",
    )
    inpaint_replace: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="The amount by which to replace masked areas with latent noise",
    )

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
        }

    def dispatch_progress(
        self,
        context: InvocationContext,
        source_node_id: str,
        intermediate_state: PipelineIntermediateState,
    ) -> None:
        stable_diffusion_step_callback(
            context=context,
            intermediate_state=intermediate_state,
            node=self.dict(),
            source_node_id=source_node_id,
        )

    def get_conditioning(self, context, unet):
        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
        c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
        extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning

        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
        uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)

        return (uc, c, extra_conditioning_info)

    @contextmanager
    def load_model_old_way(self, context, scheduler):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}),
                    context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )
        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
            context=context,
        )

        with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
            device = context.services.model_manager.mgr.cache.execution_device
            dtype = context.services.model_manager.mgr.cache.precision

            pipeline = StableDiffusionGeneratorPipeline(
                vae=vae,
                text_encoder=None,
                tokenizer=None,
                unet=unet,
                scheduler=scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
                precision="float16" if dtype == torch.float16 else "float32",
                execution_device=device,
            )

            yield OldModelInfo(
                name=self.unet.unet.model_name,
                hash="<NO-HASH>",
                model=pipeline,
            )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
        mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        with self.load_model_old_way(context, scheduler) as model:
            conditioning = self.get_conditioning(context, model.context.model.unet)

            outputs = Inpaint(model).generate(
                conditioning=conditioning,
                scheduler=scheduler,
                init_image=image,
                mask_image=mask,
                step_callback=partial(self.dispatch_progress, context, source_node_id),
                **self.dict(
                    exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
                ),  # Shorthand for passing all of the parameters above manually
            )

            # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
            # each time it is called. We only need the first one.
            generator_output = next(outputs)

        image_dto = context.services.images.create(
            image=generator_output.image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
@@ -1,69 +1,31 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from pathlib import Path
from typing import Literal, Optional

import cv2
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import Field
from pathlib import Path
from typing import Union
from PIL import Image, ImageChops, ImageFilter, ImageOps

from invokeai.app.invocations.metadata import CoreMetadata
from ..models.image import (
    ImageCategory,
    ImageField,
    ResourceOrigin,
    PILInvocationConfig,
    ImageOutput,
    MaskOutput,
)
from .baseinvocation import (
    BaseInvocation,
    InvocationContext,
    InvocationConfig,
)
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker

from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title


class LoadImageInvocation(BaseInvocation):
    """Load an image and provide it as output."""

    # fmt: off
    type: Literal["load_image"] = "load_image"

    # Inputs
    image: Optional[ImageField] = Field(
        default=None, description="The image to load"
    )
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Load Image", "tags": ["image", "load"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        return ImageOutput(
            image=ImageField(image_name=self.image.image_name),
            width=image.width,
            height=image.height,
        )


@title("Show Image")
@tags("image")
class ShowImageInvocation(BaseInvocation):
    """Displays a provided image, and passes it forward in the pipeline."""

    # Metadata
    type: Literal["show_image"] = "show_image"

    # Inputs
    image: Optional[ImageField] = Field(default=None, description="The image to show")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Show Image", "tags": ["image", "show"]},
        }
    image: ImageField = InputField(description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
@@ -79,24 +41,53 @@ class ShowImageInvocation(BaseInvocation):
        )


class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
@title("Blank Image")
@tags("image")
class BlankImageInvocation(BaseInvocation):
    """Creates a blank image and forwards it to the pipeline"""

    # Metadata
    type: Literal["blank_image"] = "blank_image"

    # Inputs
    width: int = InputField(default=512, description="The width of the image")
    height: int = InputField(default=512, description="The height of the image")
    mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
    color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )


@title("Crop Image")
@tags("image", "crop")
class ImageCropInvocation(BaseInvocation):
    """Crops an image to a specified box. The box can be outside of the image."""

    # fmt: off
    # Metadata
    type: Literal["img_crop"] = "img_crop"

    # Inputs
    image: Optional[ImageField] = Field(default=None, description="The image to crop")
    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
    height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Crop Image", "tags": ["image", "crop"]},
        }
    image: ImageField = InputField(description="The image to crop")
    x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
    y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
    width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
    height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
@@ -120,31 +111,31 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
        )


class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
@title("Paste Image")
@tags("image", "paste")
class ImagePasteInvocation(BaseInvocation):
    """Pastes an image into another image."""

    # fmt: off
    # Metadata
    type: Literal["img_paste"] = "img_paste"

    # Inputs
    base_image: Optional[ImageField] = Field(default=None, description="The base image")
    image: Optional[ImageField] = Field(default=None, description="The image to paste")
    mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
    x: int = Field(default=0, description="The left x coordinate at which to paste the image")
    y: int = Field(default=0, description="The top y coordinate at which to paste the image")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Paste Image", "tags": ["image", "paste"]},
        }
    base_image: ImageField = InputField(description="The base image")
    image: ImageField = InputField(description="The image to paste")
    mask: Optional[ImageField] = InputField(
        default=None,
        description="The mask to use when pasting",
    )
    x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
    y: int = InputField(default=0, description="The top y coordinate at which to paste the image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        base_image = context.services.images.get_pil_image(self.base_image.image_name)
        image = context.services.images.get_pil_image(self.image.image_name)
        mask = (
            None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name))
        )
        mask = None
        if self.mask is not None:
            mask = context.services.images.get_pil_image(self.mask.image_name)
            mask = ImageOps.invert(mask.convert("L"))
        # TODO: probably shouldn't invert mask here... should user be required to do it?

        min_x = min(0, self.x)
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Mask from Alpha")
|
||||
@tags("image", "mask")
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
@@ -204,28 +191,24 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Multiply Images")
|
||||
@tags("image", "multiply")
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
|
||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
|
||||
}
|
||||
image1: ImageField = InputField(description="The first image to multiply")
|
||||
image2: ImageField = InputField(description="The second image to multiply")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
@@ -252,21 +235,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Extract Image Channel")
|
||||
@tags("image", "channel")
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -292,21 +271,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Convert Image Mode")
|
||||
@tags("image", "convert")
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to convert")
|
||||
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -329,22 +304,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Blur Image")
|
||||
@tags("image", "blur")
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to blur")
|
||||
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
||||
# Metadata
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -390,23 +362,19 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Resize Image")
|
||||
@tags("image", "resize")
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to resize")
|
||||
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -434,22 +402,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Scale Image")
|
||||
@tags("image", "scale")
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to scale")
|
||||
scale_factor: float = InputField(
|
||||
default=2.0,
|
||||
gt=0,
|
||||
description="The factor by which to scale the image",
|
||||
)
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -479,28 +447,24 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Lerp Image")
|
||||
@tags("image", "lerp")
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
image_arr = image_arr * (self.max - self.min) + self.max
|
||||
image_arr = image_arr * (self.max - self.min) + self.min
|
||||
|
||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
@@ -520,25 +484,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Inverse Lerp Image")
|
||||
@tags("image", "ilerp")
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Inverse Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "inverse"],
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -564,21 +521,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Blur NSFW Image")
|
||||
@tags("image", "nsfw")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_nsfw"] = "img_nsfw"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -615,22 +570,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
|
||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Add Invisible Watermark")
|
||||
@tags("image", "watermark")
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_watermark"] = "img_watermark"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -650,3 +603,334 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Mask Edge")
|
||||
@tags("image", "mask", "inpaint")
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
type: Literal["mask_edge"] = "mask_edge"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to apply the mask to")
|
||||
edge_size: int = InputField(description="The size of the edge")
|
||||
edge_blur: int = InputField(description="The amount of blur on the edge")
|
||||
low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
|
||||
high_threshold: int = InputField(
|
||||
description="Second threshold for the hysteresis procedure in Canny edge detection"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||
npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
|
||||
npedge = cv2.Canny(npimg, threshold1=self.low_threshold, threshold2=self.high_threshold)
|
||||
npmask = npgradient + npedge
|
||||
npmask = cv2.dilate(npmask, numpy.ones((3, 3), numpy.uint8), iterations=int(self.edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if self.edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(self.edge_blur))
|
||||
|
||||
new_mask = ImageOps.invert(new_mask)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=new_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Combine Mask")
|
||||
@tags("image", "mask", "multiply")
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
type: Literal["mask_combine"] = "mask_combine"
|
||||
|
||||
# Inputs
|
||||
mask1: ImageField = InputField(description="The first mask to combine")
|
||||
mask2: ImageField = InputField(description="The second image to combine")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
||||
mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L")
|
||||
|
||||
combined_mask = ImageChops.multiply(mask1, mask2)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=combined_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Color Correct")
|
||||
@tags("image", "color")
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
using a mask to only color-correct certain regions of the target image.
|
||||
"""
|
||||
|
||||
type: Literal["color_correct"] = "color_correct"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to color-correct")
|
||||
reference: ImageField = InputField(description="Reference image for color-correction")
|
||||
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_init_mask = None
|
||||
if self.mask is not None:
|
||||
pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L")
|
||||
|
||||
init_image = context.services.images.get_pil_image(self.reference.image_name)
|
||||
|
||||
result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||
|
||||
# if init_image is None or init_mask is None:
|
||||
# return result
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
# pil_init_mask = (
|
||||
# init_mask.getchannel("A")
|
||||
# if init_mask.mode == "RGBA"
|
||||
# else init_mask.convert("L")
|
||||
# )
|
||||
pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = numpy.asarray(init_image.convert("RGB"), dtype=numpy.uint8)
|
||||
init_a_pixels = numpy.asarray(pil_init_image.getchannel("A"), dtype=numpy.uint8)
|
||||
init_mask_pixels = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||
|
||||
# Get numpy version of result
|
||||
np_image = numpy.asarray(result.convert("RGB"), dtype=numpy.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
if np_init_rgb_pixels_masked.size > 0:
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:, :, :] = (
|
||||
(
|
||||
(
|
||||
(np_matched_result[:, :, :].astype(numpy.float32) - gen_means[None, None, :])
|
||||
/ gen_std[None, None, :]
|
||||
)
|
||||
* init_std[None, None, :]
|
||||
+ init_means[None, None, :]
|
||||
)
|
||||
.clip(0, 255)
|
||||
.astype(numpy.uint8)
|
||||
)
|
||||
matched_result = Image.fromarray(np_matched_result, mode="RGB")
|
||||
else:
|
||||
matched_result = Image.fromarray(np_image, mode="RGB")
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if self.mask_blur_radius > 0:
|
||||
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||
nmd = cv2.erode(
|
||||
nm,
|
||||
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
|
||||
iterations=int(self.mask_blur_radius / 2),
|
||||
)
|
||||
pmd = Image.fromarray(nmd, mode="L")
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(self.mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, result.split()[-1])
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=matched_result,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
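The colour-correction block above is the standard per-channel statistics match: normalise the generated pixels by their own mean and standard deviation inside the mask, then rescale to the reference image's statistics. A standalone numpy sketch of the same idea (a toy helper, not the node's code):

import numpy


def match_color(target: numpy.ndarray, reference: numpy.ndarray) -> numpy.ndarray:
    """Shift target's per-channel mean/std to match reference (uint8 RGB arrays)."""
    t = target.astype(numpy.float32)
    r = reference.astype(numpy.float32)
    t_mean, t_std = t.mean(axis=(0, 1)), t.std(axis=(0, 1)) + 1e-6  # avoid divide-by-zero
    r_mean, r_std = r.mean(axis=(0, 1)), r.std(axis=(0, 1))
    matched = (t - t_mean) / t_std * r_std + r_mean
    return matched.clip(0, 255).astype(numpy.uint8)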
|
||||
|
||||
@title("Image Hue Adjustment")
|
||||
@tags("image", "hue", "hsl")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = numpy.array(pil_image.convert("HSV"))
|
||||
|
||||
# Convert hue from 0..360 to 0..256
|
||||
hue = int(256 * ((self.hue % 360) / 360))
|
||||
|
||||
# Increment each hue and wrap around at 255
|
||||
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Luminosity Adjustment")
|
||||
@tags("image", "luminosity", "hsl")
|
||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Luminosity (Value) of an image."""
|
||||
|
||||
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
luminosity: float = InputField(
|
||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Adjust the luminosity (value)
|
||||
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
|
||||
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Saturation Adjustment")
|
||||
@tags("image", "saturation", "hsl")
|
||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Saturation of an image."""
|
||||
|
||||
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Adjust the saturation
|
||||
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
|
||||
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,28 +1,25 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team

import math
from typing import Literal, Optional, get_args

import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field

from invokeai.app.invocations.image import ImageOutput
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch

from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    InvocationConfig,
    InvocationContext,
)
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title


def infill_methods() -> list[str]:
    methods = [
        "tile",
        "solid",
        "lama",
    ]
    if PatchMatch.patchmatch_available():
        methods.insert(0, "patchmatch")
@@ -33,6 +30,11 @@ INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"


def infill_lama(im: Image.Image) -> Image.Image:
    lama = LaMA()
    return lama(im)


def infill_patchmatch(im: Image.Image) -> Image.Image:
    if im.mode != "RGBA":
        return im
@@ -95,7 +97,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
replace_count = (tiles_mask == False).sum() # noqa: E712
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
@@ -114,21 +116,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@title("Solid Color Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: ColorField = Field(
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
color: ColorField = InputField(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -153,25 +154,23 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Tile Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
# Input
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -194,17 +193,15 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("PatchMatch Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -228,3 +225,34 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("LaMa Infill")
|
||||
@tags("image", "inpaint")
|
||||
class LaMaInfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
type: Literal["infill_lama"] = "infill_lama"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -4,16 +4,34 @@ from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
import torchvision.transforms as T
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@@ -24,57 +42,25 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.model_management import ModelPatcher
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
@@ -82,6 +68,7 @@ def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
@@ -98,6 +85,11 @@ def get_scheduler(
|
||||
**scheduler_extra_config,
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||
if scheduler_class is DPMSolverSDEScheduler:
|
||||
scheduler_config["noise_sampler_seed"] = seed
|
||||
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@@ -106,25 +98,41 @@ def get_scheduler(
|
||||
return scheduler
|
||||
|
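The new `seed` parameter exists because `DPMSolverSDEScheduler` only accepts its noise-sampler seed at construction time; threading the generation seed through makes `dpmpp_sde` runs repeatable. A hedged usage sketch, with the argument names taken from the signature shown above:

# Illustrative call only; "dpmpp_sde" is assumed to be a key in SCHEDULER_MAP,
# and `seed` stands for whatever generation seed the calling node holds.
scheduler = get_scheduler(
    context=context,
    scheduler_info=self.unet.scheduler,
    scheduler_name="dpmpp_sde",
    seed=seed,
)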
||||
|
||||
# Text to image
class TextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""
@title("Denoise Latents")
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
class DenoiseLatentsInvocation(BaseInvocation):
    """Denoises noisy latents to decodable images"""

    type: Literal["t2l"] = "t2l"
    type: Literal["denoise_latents"] = "denoise_latents"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
    unet: UNetField = Field(default=None, description="UNet submodel")
    control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on
    positive_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
    )
    negative_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
    )
    noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
    cfg_scale: Union[float, List[float]] = InputField(
        default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    scheduler: SAMPLER_NAME_VALUES = InputField(
        default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
    )
    unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
    control: Union[ControlField, list[ControlField]] = InputField(
        default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
    )
    latents: Optional[LatentsField] = InputField(
        description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
    )
    mask: Optional[ImageField] = InputField(
        default=None,
        description=FieldDescriptions.mask,
    )
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
@@ -138,33 +146,20 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
@@ -172,13 +167,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
unet,
|
||||
seed,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
@@ -198,7 +194,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
# flip all bits to have noise different from initial
|
||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
||||
)
|
||||
return conditioning_data
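A plain-Python illustration of the seed flip used for the ancestral/SDE generator above (hypothetical seed value):

seed = 123456789
flipped = seed ^ 0xFFFFFFFF        # invert all 32 bits -> 4171510506
assert 0 <= flipped <= 0xFFFFFFFF  # still a valid 32-bit seed
assert flipped != seed             # so the sampler noise differs from the initial noise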
@@ -231,7 +228,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||
)
|
||||
|
||||
def prep_control_data(
|
||||
@@ -310,110 +306,83 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
with SilenceWarnings():
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||
num_inference_steps = steps
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(num_inference_steps, device="cpu")
|
||||
timesteps = scheduler.timesteps.to(device=device)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
# apply denoising_start
|
||||
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
||||
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps)))
|
||||
timesteps = timesteps[t_start_idx:]
|
||||
if scheduler.order == 2 and t_start_idx > 0:
|
||||
timesteps = timesteps[1:]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
# save start timestep to apply noise
|
||||
init_timestep = timesteps[:1]
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
# apply denoising_end
|
||||
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
||||
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps)))
|
||||
if scheduler.order == 2 and t_end_idx > 0:
|
||||
t_end_idx += 1
|
||||
timesteps = timesteps[:t_end_idx]
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||
unet_info.context.model, _lora_loader()
|
||||
), unet_info as unet:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
# calculate step count based on scheduler order
|
||||
num_inference_steps = len(timesteps)
|
||||
if scheduler.order == 2:
|
||||
num_inference_steps += num_inference_steps % 2
|
||||
num_inference_steps = num_inference_steps // 2
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
return num_inference_steps, timesteps, init_timestep
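A small numeric sketch of the denoising_start/denoising_end trimming implemented in init_scheduler above (pure Python, a made-up 10-step schedule over 1000 training timesteps):

num_train_timesteps = 1000
timesteps = [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]  # roughly what set_timesteps(10) yields
denoising_start, denoising_end = 0.3, 0.8

t_start_val = int(round(num_train_timesteps * (1 - denoising_start)))  # 700
t_start_idx = len([t for t in timesteps if t >= t_start_val])          # 3
timesteps = timesteps[t_start_idx:]                                    # [699, 599, 499, 399, 299, 199, 99]
init_timestep = timesteps[:1]                                          # [699] -> where noise is injected

t_end_val = int(round(num_train_timesteps * (1 - denoising_end)))      # 200
t_end_idx = len([t for t in timesteps if t >= t_end_val])              # 5
timesteps = timesteps[:t_end_idx]                                      # [699, 599, 499, 399, 299]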
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
def prep_mask_tensor(self, mask, context, latents):
if mask is None:
|
||||
return None
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline,
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Latent To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
mask_image = context.services.images.get_pil_image(mask.image_name)
|
||||
if mask_image.mode != "L":
|
||||
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
mask_tensor = tv_resize(mask_tensor, latents.shape[-2:], T.InterpolationMode.BILINEAR)
return 1 - mask_tensor
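A self-contained sketch of the resize-and-invert step in prep_mask_tensor, using a dummy tensor in place of the PIL mask (torch + torchvision only):

import torch
import torchvision.transforms.functional as TF

mask = torch.zeros(1, 1, 512, 512)
mask[:, :, 128:384, 128:384] = 1.0            # 1.0 where the (hypothetical) white mask pixels were

latent_h, latent_w = 64, 64                   # stand-in for latents.shape[-2:]
mask = TF.resize(mask, [latent_h, latent_w])  # bilinear by default, matching the node above
mask = 1 - mask                               # masked area -> 0, unmasked area -> 1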
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
else:
|
||||
latents = torch.zeros_like(noise)
|
||||
|
||||
if seed is None:
|
||||
seed = 0
|
||||
|
||||
mask = self.prep_mask_tensor(self.mask, context, latents)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
@@ -432,44 +401,48 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||
unet_info.context.model, _lora_loader()
|
||||
), unet_info as unet:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
if mask is not None:
|
||||
mask = mask.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline,
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
latents_shape=latents.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = (
|
||||
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
@@ -481,32 +454,32 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@title("Latents to Image")
|
||||
@tags("latents", "image", "vae", "l2i")
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||
metadata: Optional[CoreMetadata] = Field(
|
||||
default=None, description="Optional core metadata to be written to the image"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
metadata: CoreMetadata = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Latents To Image",
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@@ -584,24 +557,30 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
|
||||
|
||||
@title("Resize Latents")
|
||||
@tags("latents", "resize")
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
|
||||
}
|
||||
width: int = InputField(
|
||||
ge=64,
|
||||
multiple_of=8,
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
height: int = InputField(
|
||||
ge=64,
|
||||
multiple_of=8,
|
||||
description=FieldDescriptions.height,
)
|
||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -623,26 +602,24 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@title("Scale Latents")
|
||||
@tags("latents", "resize")
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
|
||||
}
|
||||
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
|
||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -665,25 +642,26 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@title("Image to Latents")
|
||||
@tags("latents", "image", "vae", "i2l")
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(description="The image to encode")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Encode latents by overlapping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
|
||||
}
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@@ -746,4 +724,82 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
latents = latents.to("cpu")
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
|
||||
@title("Blend Latents")
|
||||
@tags("latents", "blend")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
|
||||
type: Literal["lblend"] = "lblend"
|
||||
|
||||
# Inputs
|
||||
latents_a: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents_b: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.services.latents.get(self.latents_a.latents_name)
|
||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise "Latents to blend must be the same size."
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
|
||||
# blend
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, blended_latents)
|
||||
return build_latents_output(latents_name=name, latents=blended_latents)
|
||||
|
||||
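A minimal usage sketch of the spherical blend above, outside the invocation graph; dummy 4-channel latents, and only the non-collinear branch of slerp is reproduced:

import numpy as np
import torch

latents_a = torch.randn(1, 4, 64, 64)
latents_b = torch.randn(1, 4, 64, 64)
alpha = 0.3

v0, v1 = latents_a.numpy().ravel(), latents_b.numpy().ravel()
dot = np.clip(np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1)), -1.0, 1.0)
theta = np.arccos(dot)
blended = (np.sin((1 - alpha) * theta) * v0 + np.sin(alpha * theta) * v1) / np.sin(theta)
blended_latents = torch.from_numpy(blended.reshape(latents_a.shape))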
@@ -2,134 +2,83 @@

from typing import Literal

from pydantic import BaseModel, Field
import numpy as np

from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)
from invokeai.app.invocations.primitives import IntegerOutput

from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
class MathInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all math invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["math"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""A float output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["float_output"] = "float_output"
|
||||
param: float = Field(default=None, description="The output float")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Add Integers")
|
||||
@tags("math")
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Add", "tags": ["math", "add"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Subtract Integers")
|
||||
@tags("math")
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Multiply Integers")
|
||||
@tags("math")
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Divide Integers")
|
||||
@tags("math")
|
||||
class DivideInvocation(BaseInvocation):
|
||||
"""Divides two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Divide", "tags": ["math", "divide"]},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@title("Random Integer")
|
||||
@tags("math")
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
|
||||
}
|
||||
# Inputs
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
||||
@@ -1,30 +1,38 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
class LoRAMetadataField(BaseModelExcludeNull):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
|
||||
lora: LoRAModelField = Field(description="The LoRA model")
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
|
||||
|
||||
class CoreMetadata(BaseModel):
|
||||
class CoreMetadata(BaseModelExcludeNull):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
created_by: Optional[str] = Field(description="The name of the creator of the image")
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
@@ -40,37 +48,40 @@ class CoreMetadata(BaseModel):
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
vae: Optional[VAEModelField] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# Latents-to-Latents
|
||||
strength: Union[float, None] = Field(
|
||||
strength: Optional[float] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Optional[float] = Field(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_positive_aesthetic_store: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||
refiner_negative_aesthetic_store: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
class ImageMetadata(BaseModelExcludeNull):
|
||||
"""An image's generation metadata"""
|
||||
|
||||
metadata: Optional[dict] = Field(
|
||||
@@ -85,66 +96,86 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
||||
|
||||
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||
|
||||
|
||||
@title("Metadata Accumulator")
|
||||
@tags("metadata")
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||
|
||||
generation_mode: str = Field(
|
||||
generation_mode: str = InputField(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
height: int = Field(description="The height parameter")
|
||||
seed: int = Field(description="The seed used for noise generation")
|
||||
rand_device: str = Field(description="The device used for random number generation")
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(
|
||||
positive_prompt: str = InputField(description="The positive prompt parameter")
|
||||
negative_prompt: str = InputField(description="The negative prompt parameter")
|
||||
width: int = InputField(description="The width parameter")
|
||||
height: int = InputField(description="The height parameter")
|
||||
seed: int = InputField(description="The seed used for noise generation")
|
||||
rand_device: str = InputField(description="The device used for random number generation")
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: int = InputField(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||
strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
init_image: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||
positive_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The positive style prompt parameter",
|
||||
)
|
||||
negative_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The negative style prompt parameter",
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
refiner_cfg_scale: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
refiner_steps: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The number of steps used for the refiner",
|
||||
)
|
||||
refiner_scheduler: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The scheduler used for the refiner",
|
||||
)
|
||||
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_start: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The start value used for refiner denoising",
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Metadata Accumulator",
|
||||
"tags": ["image", "metadata", "generation"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
import copy
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
@@ -39,13 +50,11 @@ class VaeField(BaseModel):
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output"] = "model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
@@ -63,24 +72,17 @@ class LoRAModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@title("Main Model")
|
||||
@tags("model")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["main_model_loader"] = "main_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
@@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: off
|
||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("LoRA")
|
||||
@tags("lora", "model")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["lora_loader"] = "lora_loader"
|
||||
|
||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
"type_hints": {"lora": "lora_model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
@@ -262,6 +245,101 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("SDXL LoRA")
|
||||
@tags("sdxl", "lora", "model")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
)
clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
)
clip2: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
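A stripped-down, plain-Python sketch of the copy-then-append pattern the LoRA loaders above follow (the dataclass below is a stand-in, not the real UNetField):

import copy
from dataclasses import dataclass, field

@dataclass
class FakeUNetField:  # stand-in for UNetField
    loras: list = field(default_factory=list)

def apply_lora(unet: FakeUNetField, lora_name: str, weight: float) -> FakeUNetField:
    if any(entry["model_name"] == lora_name for entry in unet.loras):
        raise Exception(f'Lora "{lora_name}" already applied to unet')
    patched = copy.deepcopy(unet)  # never mutate the incoming field
    patched.loras.append({"model_name": lora_name, "weight": weight, "submodel": None})
    return patched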
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
@@ -272,29 +350,23 @@ class VAEModelField(BaseModel):
|
||||
class VaeLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
# fmt: on
|
||||
# Outputs
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("VAE")
|
||||
@tags("vae", "model")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
type: Literal["vae_loader"] = "vae_loader"
|
||||
|
||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "VAE Loader",
|
||||
"tags": ["vae", "loader"],
|
||||
"type_hints": {"vae_model": "vae_model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, validator
|
||||
import torch
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from pydantic import validator
|
||||
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -61,62 +65,53 @@ Nodes
class NoiseOutput(BaseInvocationOutput):
    """Invocation noise output"""

    # fmt: off
    type: Literal["noise_output"] = "noise_output"
    type: Literal["noise_output"] = "noise_output"

    # Inputs
    noise: LatentsField = Field(default=None, description="The output noise")
    width: int = Field(description="The width of the noise in pixels")
    height: int = Field(description="The height of the noise in pixels")
    # fmt: on
    noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
    width: int = OutputField(description=FieldDescriptions.width)
    height: int = OutputField(description=FieldDescriptions.height)


def build_noise_output(latents_name: str, latents: torch.Tensor):
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
    return NoiseOutput(
        noise=LatentsField(latents_name=latents_name),
        noise=LatentsField(latents_name=latents_name, seed=seed),
        width=latents.size()[3] * 8,
        height=latents.size()[2] * 8,
    )
@title("Noise")
|
||||
@tags("latents", "noise")
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use",
|
||||
description=FieldDescriptions.seed,
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
width: int = Field(
|
||||
width: int = InputField(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the resulting noise",
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
height: int = Field(
|
||||
height: int = InputField(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the resulting noise",
|
||||
description=FieldDescriptions.height,
|
||||
)
|
||||
use_cpu: bool = Field(
|
||||
use_cpu: bool = InputField(
|
||||
default=True,
|
||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Noise",
|
||||
"tags": ["latents", "noise"],
|
||||
},
|
||||
}
|
||||
|
||||
@validator("seed", pre=True)
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
@@ -132,4 +127,4 @@ class NoiseInvocation(BaseInvocation):
|
||||
)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
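A self-contained sketch of seeded, CPU-side noise generation in the spirit of NoiseInvocation (torch only; assumes the usual 4-channel SD latent layout):

import torch

seed, width, height = 1234, 512, 512
generator = torch.Generator(device="cpu").manual_seed(seed)  # CPU generation keeps results identical across GPUs
noise = torch.randn(
    (1, 4, height // 8, width // 8),
    generator=generator,
    device="cpu",
    dtype=torch.float32,
)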
@@ -1,37 +1,42 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)

from contextlib import ExitStack
import inspect
import re

# from contextlib import ExitStack
from typing import List, Literal, Optional, Union

import re
import inspect

from pydantic import BaseModel, Field, validator
import torch
import numpy as np
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler

from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from pydantic import BaseModel, Field, validator
from tqdm import tqdm

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType

from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState

from tqdm import tqdm
from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput

from ...backend.util import choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    InputField,
    Input,
    InvocationContext,
    OutputField,
    UIComponent,
    UIType,
    tags,
    title,
)
from .controlnet_image_processors import ControlField
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .model import ClipField, ModelInfo, UNetField, VaeField

ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
@@ -51,21 +56,22 @@ ORT_TO_NP_TYPE = {
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||
|
||||
|
||||
@title("ONNX Prompt (Raw)")
|
||||
@tags("onnx", "prompt")
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.clip.loras
|
||||
@@ -76,18 +82,14 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
# stack.enter_context(
|
||||
# context.services.model_manager.get_model(
|
||||
# model_name=name,
|
||||
# base_model=self.clip.text_encoder.base_model,
|
||||
# model_type=ModelType.TextualInversion,
|
||||
# )
|
||||
# )
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# print(e)
|
||||
@@ -131,7 +133,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@@ -139,25 +141,48 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# Text to image
|
||||
@title("ONNX Text to Latents")
|
||||
@tags("latents", "inference", "txt2img", "onnx")
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
noise: LatentsField = InputField(
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
)
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
ui_type=UIType.Control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
@@ -171,20 +196,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@@ -217,6 +228,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=0, # TODO: refactor this node
|
||||
)
|
||||
|
||||
def torch2numpy(latent: torch.Tensor):
|
||||
@@ -246,7 +258,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
|
||||
with unet_info as unet, ExitStack() as stack:
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
@@ -304,26 +316,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# Latent to image
|
||||
@title("ONNX Latents to Image")
|
||||
@tags("latents", "image", "vae", "onnx")
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
metadata: Optional[CoreMetadata] = Field(
|
||||
default=None, description="Optional core metadata to be written to the image"
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.denoised_latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
# tiled: bool = InputField(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -377,89 +391,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
||||
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
|
||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
model_name = "stable-diffusion-v1-5"
|
||||
base_model = BaseModelType.StableDiffusion1
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.ONNX,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {model_name}!")
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.ONNX,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
@@ -468,22 +406,17 @@ class OnnxModelField(BaseModel):
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@title("ONNX Main Model")
|
||||
@tags("onnx", "model")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||
|
||||
model: OnnxModelField = Field(description="The model to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Onnx Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
model: OnnxModelField = InputField(
|
||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
|
||||
@@ -1,73 +1,61 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
from typing import Literal, Optional
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import PIL.Image
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseOut,
|
||||
CubicEaseInOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseOut,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseOut,
|
||||
SineEaseInOut,
|
||||
SineEaseIn,
|
||||
SineEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseInOut,
|
||||
CubicEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseOut,
|
||||
LinearInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseInOut,
|
||||
QuadEaseOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseOut,
|
||||
SineEaseIn,
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
@title("Float Range")
|
||||
@tags("math", "range")
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
||||
}
|
||||
start: float = InputField(default=5, description="The first value of the range")
|
||||
stop: float = InputField(default=10, description="The last value of the range")
|
||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
@@ -108,37 +96,32 @@ EASING_FUNCTIONS_MAP = {
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any])
|
||||
@title("Step Param Easing")
|
||||
@tags("step", "easing")
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
||||
}
|
||||
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
|
||||
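The easing node above builds its per-step parameter list with the `easing_functions` package. Purely to illustrate the shape of that list (this snippet is not from the diff and does not use that package), here is a quadratic ease-in-out between `start_value=0.0` and `end_value=1.0` over ten denoising steps, in plain numpy:

```python
import numpy as np

num_steps = 10
t = np.linspace(0.0, 1.0, num_steps)
# quadratic ease-in-out: slow start, fast middle, slow finish
eased = np.where(t < 0.5, 2 * t**2, 1 - ((-2 * t + 2) ** 2) / 2)
print([round(float(v), 3) for v in eased])
# [0.0, 0.025, 0.099, 0.222, 0.395, 0.605, 0.778, 0.901, 0.975, 1.0]
```

A list like this is what the node wraps in a FloatCollectionOutput, one value per step.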
@@ -1,83 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.prompt import PromptOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .math import FloatOutput, IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""A string output"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = Field(default=None, description="The output string")
|
||||
|
||||
|
||||
class ParamStringInvocation(BaseInvocation):
|
||||
"""A string parameter"""
|
||||
|
||||
type: Literal["param_string"] = "param_string"
|
||||
text: str = Field(default="", description="The string value")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
class ParamPromptInvocation(BaseInvocation):
|
||||
"""A prompt input parameter"""
|
||||
|
||||
type: Literal["param_prompt"] = "param_prompt"
|
||||
prompt: str = Field(default="", description="The prompt value")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
||||
return PromptOutput(prompt=self.prompt)
|
||||
invokeai/app/invocations/primitives.py (new file, 479 lines added)
@@ -0,0 +1,479 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional, Tuple

import torch
from pydantic import BaseModel, Field

from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    UIComponent,
    UIType,
    tags,
    title,
)

"""
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
- primitive nodes
- primitive outputs
- primitive collection outputs
"""
|
||||
|
||||
# region Boolean
|
||||
|
||||
|
||||
class BooleanOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single boolean"""
|
||||
|
||||
type: Literal["boolean_output"] = "boolean_output"
|
||||
value: bool = OutputField(description="The output boolean")
|
||||
|
||||
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of booleans"""
|
||||
|
||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||
|
||||
|
||||
@title("Boolean Primitive")
|
||||
@tags("primitives", "boolean")
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
# Inputs
|
||||
value: bool = InputField(default=False, description="The boolean value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||
return BooleanOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Boolean Primitive Collection")
|
||||
@tags("primitives", "boolean", "collection")
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
type: Literal["boolean_collection"] = "boolean_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[bool] = InputField(
|
||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Integer
|
||||
|
||||
|
||||
class IntegerOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single integer"""
|
||||
|
||||
type: Literal["integer_output"] = "integer_output"
|
||||
value: int = OutputField(description="The output integer")
|
||||
|
||||
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of integers"""
|
||||
|
||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||
|
||||
|
||||
@title("Integer Primitive")
|
||||
@tags("primitives", "integer")
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
|
||||
type: Literal["integer"] = "integer"
|
||||
|
||||
# Inputs
|
||||
value: int = InputField(default=0, description="The integer value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Integer Primitive Collection")
|
||||
@tags("primitives", "integer", "collection")
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
type: Literal["integer_collection"] = "integer_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[int] = InputField(
|
||||
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Float
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single float"""
|
||||
|
||||
type: Literal["float_output"] = "float_output"
|
||||
value: float = OutputField(description="The output float")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
type: Literal["float_collection_output"] = "float_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||
|
||||
|
||||
@title("Float Primitive")
|
||||
@tags("primitives", "float")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
type: Literal["float"] = "float"
|
||||
|
||||
# Inputs
|
||||
value: float = InputField(default=0.0, description="The float value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(value=self.value)
|
||||
|
||||
|
||||
@title("Float Primitive Collection")
|
||||
@tags("primitives", "float", "collection")
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[float] = InputField(
|
||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region String
|
||||
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single string"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
value: str = OutputField(description="The output string")
|
||||
|
||||
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
type: Literal["string_collection_output"] = "string_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||
|
||||
|
||||
@title("String Primitive")
|
||||
@tags("primitives", "string")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
# Inputs
|
||||
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(value=self.value)
|
||||
|
||||
|
||||
@title("String Primitive Collection")
|
||||
@tags("primitives", "string", "collection")
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
type: Literal["string_collection"] = "string_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[str] = InputField(
|
||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = OutputField(description="The output image")
|
||||
width: int = OutputField(description="The width of the image in pixels")
|
||||
height: int = OutputField(description="The height of the image in pixels")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
type: Literal["image_collection_output"] = "image_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||
|
||||
|
||||
@title("Image Primitive")
|
||||
@tags("primitives", "image")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
# Metadata
|
||||
type: Literal["image"] = "image"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Primitive Collection")
|
||||
@tags("primitives", "image", "collection")
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ImageField] = InputField(
|
||||
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latents
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate these latents")
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single latents tensor"""
|
||||
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
latents: LatentsField = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
|
||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of latents tensors"""
|
||||
|
||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
|
||||
|
||||
@title("Latents Primitive")
|
||||
@tags("primitives", "latents")
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
|
||||
type: Literal["latents"] = "latents"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@title("Latents Primitive Collection")
|
||||
@tags("primitives", "latents", "collection")
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
|
||||
type: Literal["latents_collection"] = "latents_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[LatentsField] = InputField(
|
||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
# endregion
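A small usage sketch for `build_latents_output` (not part of the file): it assumes the usual 4-channel `[batch, channels, height/8, width/8]` latent layout implied by the `size()[3] * 8` / `size()[2] * 8` arithmetic above, and the latents name string is invented for illustration:

```python
import torch

from invokeai.app.invocations.primitives import build_latents_output

latents = torch.zeros(1, 4, 64, 64)  # latent tensor for a 512x512 image
output = build_latents_output("graph123__node456", latents, seed=42)
assert output.width == 512 and output.height == 512
assert output.latents.seed == 42
```

Passing the seed through here is what lets downstream nodes recover the noise seed from a LatentsField, as the NoiseInvocation change earlier in this diff relies on.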
|
||||
|
||||
# region Color
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ColorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single color"""
|
||||
|
||||
type: Literal["color_output"] = "color_output"
|
||||
color: ColorField = OutputField(description="The output color")
|
||||
|
||||
|
||||
class ColorCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of colors"""
|
||||
|
||||
type: Literal["color_collection_output"] = "color_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||
|
||||
|
||||
@title("Color Primitive")
|
||||
@tags("primitives", "color")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
type: Literal["color"] = "color"
|
||||
|
||||
# Inputs
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ColorOutput:
|
||||
return ColorOutput(color=self.color)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Conditioning
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of the conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
type: Literal["conditioning_output"] = "conditioning_output"
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
|
||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of conditioning tensors"""
|
||||
|
||||
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
|
||||
@title("Conditioning Primitive")
|
||||
@tags("primitives", "conditioning")
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
type: Literal["conditioning"] = "conditioning"
|
||||
|
||||
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@title("Conditioning Primitive Collection")
|
||||
@tags("primitives", "conditioning", "collection")
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
||||
type: Literal["conditioning_collection"] = "conditioning_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
return ConditioningCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -1,59 +1,28 @@
|
||||
from os.path import exists
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from pydantic import validator
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a prompt"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt"] = "prompt"
|
||||
|
||||
prompt: str = Field(default=None, description="The output prompt")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"prompt",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class PromptCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a collection of prompts"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
||||
|
||||
prompt_collection: list[str] = Field(description="The output prompt collection")
|
||||
count: int = Field(description="The size of the prompt collection")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "prompt_collection", "count"]}
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
|
||||
|
||||
|
||||
@title("Dynamic Prompt")
|
||||
@tags("prompt", "collection")
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
|
||||
}
|
||||
# Inputs
|
||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
if self.combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
||||
@@ -61,27 +30,26 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
|
||||
@title("Prompts from File")
|
||||
@tags("prompt", "file")
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||
type: Literal["prompt_from_file"] = "prompt_from_file"
|
||||
|
||||
# Inputs
|
||||
file_path: str = Field(description="Path to prompt text file")
|
||||
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start from")
|
||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
|
||||
}
|
||||
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
|
||||
pre_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||
)
|
||||
post_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
||||
)
|
||||
start_line: int = InputField(default=1, ge=1, description="Line in the file to start from")
|
||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
|
||||
@validator("file_path")
|
||||
def file_path_exists(cls, v):
|
||||
@@ -89,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
raise ValueError(FileNotFoundError)
|
||||
return v
|
||||
|
||||
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
|
||||
def promptsFromFile(
|
||||
self,
|
||||
file_path: str,
|
||||
pre_prompt: Union[str, None],
|
||||
post_prompt: Union[str, None],
|
||||
start_line: int,
|
||||
max_prompts: int,
|
||||
):
|
||||
prompts = []
|
||||
start_line -= 1
|
||||
end_line = start_line + max_prompts
|
||||
@@ -103,8 +78,8 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
break
|
||||
return prompts
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||
)
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
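To make the `start_line` / `max_prompts` semantics above concrete (1-based start line, `0` meaning "read everything"), here is a hypothetical standalone restatement of the windowing step; it is an illustration only, not the node's exact implementation, and it omits the pre_prompt/post_prompt concatenation the real method performs:

```python
def prompt_window(lines: list[str], start_line: int, max_prompts: int) -> list[str]:
    """Return the slice of lines the node would turn into prompts."""
    start = start_line - 1  # start_line is 1-based
    end = start + max_prompts if max_prompts > 0 else len(lines)
    return [line.strip() for line in lines[start:end]]


assert prompt_window(["a", "b", "c", "d"], start_line=2, max_prompts=2) == ["b", "c"]
assert prompt_window(["a", "b", "c", "d"], start_line=3, max_prompts=0) == ["c", "d"]
```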
@@ -1,62 +1,55 @@
|
||||
import torch
|
||||
import inspect
|
||||
from tqdm import tqdm
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field, validator
|
||||
from typing import Literal
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
|
||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
||||
from .compel import ConditioningField
|
||||
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
|
||||
|
||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
# fmt: on
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("SDXL Main Model")
|
||||
@tags("model", "sdxl")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Model Loader",
|
||||
"tags": ["model", "loader", "sdxl"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@@ -129,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Model")
|
||||
@tags("model", "sdxl", "refiner")
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Model Loader",
|
||||
"tags": ["model", "loader", "sdxl_refiner"],
|
||||
"type_hints": {"model": "refiner_model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@@ -201,506 +191,3 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_sdxl"] = "t2l_sdxl"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
) -> None:
|
||||
stable_diffusion_xl_step_callback(
|
||||
context=context,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
sample=sample,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
num_inference_steps = self.steps
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype) * scheduler.init_noise_sigma
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
)
|
||||
|
||||
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
|
||||
|
||||
# apply denoising_end
|
||||
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||
|
||||
if not context.services.configuration.sequential_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# del noise_pred_uncond
|
||||
# del noise_pred_text
|
||||
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
|
||||
noise_pred_uncond = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||
noise_pred_text = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# del noise_pred
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
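A minimal sketch (not part of the diff; the step count and 0.8 split are assumptions) of how the denoising_end arithmetic above is meant to pair with denoising_start on a refiner pass, so the two invocations share one schedule:

# base pass: t2l_sdxl with denoising_end=0.8 keeps the first 80% of the steps
# refiner pass: l2l_sdxl with denoising_start=0.8 runs the remaining 20%
def split_steps(total_steps: int = 30, split: float = 0.8) -> tuple[int, int]:
    base_steps = total_steps - int(round((1 - split) * total_steps))
    refiner_steps = total_steps - int(round(split * total_steps))
    return base_steps, refiner_steps

print(split_steps())  # (24, 6)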
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["l2l_sdxl"] = "l2l_sdxl"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
latents: Optional[LatentsField] = Field(description="Initial latents")
|
||||
|
||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
|
||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Latents to Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
) -> None:
|
||||
stable_diffusion_xl_step_callback(
|
||||
context=context,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
sample=sample,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
# apply denoising_start
|
||||
num_inference_steps = self.steps
|
||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||
|
||||
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# apply noise (if provided)
|
||||
if self.noise is not None and timesteps.shape[0] > 0:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||
del noise
|
||||
|
||||
# apply scheduler extra args
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
||||
|
||||
# apply denoising_end
|
||||
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||
|
||||
if not context.services.configuration.sequential_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# del noise_pred_uncond
|
||||
# del noise_pred_text
|
||||
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||
noise_pred_uncond = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||
noise_pred_text = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# del noise_pred_text
|
||||
# del noise_pred_uncond
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# del noise_pred
|
||||
# import gc
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
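A small illustration (an assumption about intent, not code from the diff) of the trade-off between the two branches above: batched CFG doubles the batch for a single UNet call, while sequential_guidance makes two smaller calls to lower peak activation memory at the cost of speed.

def unet_call_batches(sequential: bool, batch_size: int = 1) -> list[int]:
    # batch size of each UNet forward pass needed for one denoising step
    return [batch_size, batch_size] if sequential else [2 * batch_size]

print(unet_call_batches(False))  # [2]    - one doubled call
print(unet_call_batches(True))   # [1, 1] - two calls, lower peak memory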
@@ -1,18 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from realesrgan import RealESRGANer
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[
|
||||
]
|
||||
|
||||
|
||||
@title("Upscale (RealESRGAN)")
|
||||
@tags("esrgan", "upscale")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
type: Literal["esrgan"] = "esrgan"
|
||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -1,31 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
@@ -36,50 +13,6 @@ class ProgressImage(BaseModel):
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
@@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
@@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE board_id = ? AND image_name = ?;
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(board_id, image_name),
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import List, Union, Optional
|
||||
from typing import Optional
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardRecord,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
@@ -31,7 +27,6 @@ class BoardImagesServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
@@ -93,10 +88,9 @@ class BoardImagesService(BoardImagesServiceABC):
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
||||
self._services.board_image_records.remove_image_from_board(image_name)
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import sqlite3
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import (
|
||||
BoardRecord,
|
||||
deserialize_board_record,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
|
||||
|
||||
@@ -230,7 +229,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
@@ -241,7 +240,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
|
||||
8  invokeai/app/services/config/__init__.py  Normal file
@@ -0,0 +1,8 @@
"""
Init file for InvokeAI configure package
"""

from .invokeai_config import (  # noqa F401
    InvokeAIAppConfig,
    get_invokeai_config,
)

239  invokeai/app/services/config/base.py  Normal file
@@ -0,0 +1,239 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Base class for the InvokeAI configuration system.
|
||||
It defines a type of pydantic BaseSettings object that
|
||||
is able to read and write from an omegaconf-based config file,
|
||||
with overriding of settings from environment variables and/or
|
||||
the command line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import os
|
||||
import pydoc
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
value = getattr(opt, name)
|
||||
if isinstance(value, ListConfig):
|
||||
value = list(value)
|
||||
elif isinstance(value, DictConfig):
|
||||
value = dict(value)
|
||||
setattr(self, name, value)
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == Union:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=int_or_float_or_str,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
return str(value)
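A few hedged examples (mine, not from the file) of how int_or_float_or_str resolves command-line strings, which is what lets Union-typed settings such as ram accept either a number or "auto":

assert int_or_float_or_str("2") == 2
assert int_or_float_or_str("2.75") == 2.75
assert int_or_float_or_str("auto") == "auto"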
@@ -10,38 +10,49 @@ categories returned by `invokeai --help`. The file looks like this:

[file: invokeai.yaml]

InvokeAI:
  Paths:
    root: /home/lstein/invokeai-main
    conf_path: configs/models.yaml
    legacy_conf_dir: configs/stable-diffusion
    outdir: outputs
    autoimport_dir: null
  Models:
    model: stable-diffusion-1.5
    embeddings: true
  Memory/Performance:
    xformers_enabled: false
    sequential_guidance: false
    precision: float16
    max_cache_size: 6
    max_vram_cache_size: 2.7
    always_use_cpu: false
    free_gpu_mem: false
  Features:
    restore: true
    esrgan: true
    patchmatch: true
    internet_available: true
    log_tokenization: false
  Web Server:
    host: 127.0.0.1
    port: 8081
    port: 9090
    allow_origins: []
    allow_credentials: true
    allow_methods:
      - '*'
    allow_headers:
      - '*'
  Features:
    esrgan: true
    internet_available: true
    log_tokenization: false
    patchmatch: true
    ignore_missing_core_models: false
  Paths:
    autoimport_dir: autoimport
    lora_dir: null
    embedding_dir: null
    controlnet_dir: null
    conf_path: configs/models.yaml
    models_dir: models
    legacy_conf_dir: configs/stable-diffusion
    db_dir: databases
    outdir: /home/lstein/invokeai-main/outputs
    use_memory_db: false
  Logging:
    log_handlers:
      - console
    log_format: plain
    log_level: info
  Model Cache:
    ram: 13.5
    vram: 0.25
    lazy_offload: true
  Device:
    device: auto
    precision: auto
  Generation:
    sequential_guidance: false
    attention_type: xformers
    attention_slice_size: auto
    force_tiled_decode: false

The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can supersede this by providing any
@@ -55,24 +66,23 @@ InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:

  conf.parse_args(argv=['--xformers_enabled'])
  conf.parse_args(argv=['--log_tokenization'])

It is also possible to set a value at initialization time. However, if
you call parse_args() it may be overwritten.

  conf = InvokeAIAppConfig(xformers_enabled=True)
  conf.parse_args(argv=['--no-xformers'])
  conf.xformers_enabled
  conf = InvokeAIAppConfig(log_tokenization=True)
  conf.parse_args(argv=['--no-log_tokenization'])
  conf.log_tokenization
  # False


To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:

  conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
  conf.parse_args(argv=['--no-xformers'])
  conf.xformers_enabled
  conf = InvokeAIAppConfig.get_config(log_tokenization=True)
  conf.parse_args(argv=['--no-log_tokenization'])
  conf.log_tokenization
  # True

Any setting can be overwritten by setting an environment variable of
@@ -94,7 +104,7 @@ Typical usage at the top level file:
  # get global configuration and print its cache size
  conf = InvokeAIAppConfig.get_config()
  conf.parse_args()
  print(conf.max_cache_size)
  print(conf.ram_cache_size)

Typical usage in a backend module:

@@ -102,8 +112,7 @@ Typical usage in a backend module:

  # get global configuration and print its cache size value
  conf = InvokeAIAppConfig.get_config()
  print(conf.max_cache_size)

  print(conf.ram_cache_size)

Computed properties:

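A hedged usage sketch (names inferred from the add_parser_arguments code above, not quoted from the docstring): any field can also be overridden through an INVOKEAI_-prefixed environment variable before the settings are parsed.

import os
os.environ["INVOKEAI_port"] = "9191"   # overrides Web Server -> port
os.environ["INVOKEAI_ram"] = "4.5"     # overrides Model Cache -> ram
conf = InvokeAIAppConfig.get_config()
conf.parse_args([])
print(conf.port, conf.ram)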
@@ -160,208 +169,18 @@ two configs are kept in separate sections of the config file:
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
||||
from pydantic import Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Literal, Union, Optional, get_type_hints
|
||||
|
||||
from .base import InvokeAISettings
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt, name))
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf", "cached_root"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"gpu_mem_reserved",
|
||||
"max_loaded_models",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"restore",
|
||||
"root",
|
||||
"nsfw_checker",
|
||||
"cached_root",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
@@ -378,6 +197,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
|
||||
# WEB
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
@@ -385,25 +206,15 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
|
||||
# FEATURES
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
|
||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
|
||||
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
# PATHS
|
||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
@@ -413,20 +224,46 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
|
||||
# LOGGING
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
# DEVICE
|
||||
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
|
||||
# GENERATION
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAIN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
|
||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
@@ -441,7 +278,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
@@ -460,7 +297,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
or type(cls.singleton_config) != cls
|
||||
or type(cls.singleton_config) is not cls
|
||||
or (kwargs and cls.singleton_init != kwargs)
|
||||
):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
@@ -472,15 +309,12 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
Path to the runtime root directory
|
||||
"""
|
||||
# we cache value of root to protect against it being '.' and the cwd changing
|
||||
if self.cached_root:
|
||||
root = self.cached_root
|
||||
elif self.root:
|
||||
if self.root:
|
||||
root = Path(self.root).expanduser().absolute()
|
||||
else:
|
||||
root = self.find_root()
|
||||
self.cached_root = root
|
||||
return self.cached_root
|
||||
root = self.find_root().expanduser().absolute()
|
||||
self.root = root # insulate ourselves from relative paths that may change
|
||||
return root
|
||||
|
||||
@property
|
||||
def root_dir(self) -> Path:
|
||||
@@ -547,11 +381,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision == "float32"
|
||||
|
||||
@property
|
||||
def disable_xformers(self) -> bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self) -> bool:
|
||||
"""Return true if patchmatch true"""
|
||||
@@ -567,6 +396,27 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> float:
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> float:
|
||||
return self.max_vram_cache_size or self.vram
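For orientation, a minimal sketch of how the cache settings above are read at runtime, assuming an installed InvokeAI environment (illustrative, not part of the diff):

```python
# Illustrative sketch: read the cache-related settings via the config singleton.
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()
# ram_cache_size / vram_cache_size fall back to the deprecated max_cache_size /
# max_vram_cache_size values when those are present in a pre-3.1 config file.
print(config.ram_cache_size, config.vram_cache_size, config.lazy_offload)
```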
@property
|
||||
def use_cpu(self) -> bool:
|
||||
return self.always_use_cpu or self.device == "cpu"
|
||||
|
||||
@property
|
||||
def disable_xformers(self) -> bool:
|
||||
"""
|
||||
Return true if xformers_enabled is false (reversed logic)
|
||||
and attention type is not set to xformers.
|
||||
"""
|
||||
disabled_in_config = not self.xformers_enabled
|
||||
return disabled_in_config and self.attention_type != "xformers"
|
||||
|
||||
@staticmethod
|
||||
def find_root() -> Path:
|
||||
"""
|
||||
@@ -576,19 +426,19 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
"""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
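As a quick illustration of the resolution order above (environment variable first, then a virtualenv sibling init file, then ~/invokeai), setting INVOKEAI_ROOT wins; the path below is hypothetical:

```python
# Illustrative only; "/data/invokeai" is a made-up path.
import os
from invokeai.app.services.config import InvokeAIAppConfig

os.environ["INVOKEAI_ROOT"] = "/data/invokeai"
print(InvokeAIAppConfig.find_root())  # -> /data/invokeai
```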
@@ -1,8 +1,8 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
|
||||
from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocation
|
||||
from ..invocations.image import ImageNSFWBlurInvocation
|
||||
from ..invocations.noise import NoiseInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from ..invocations.primitives import IntegerInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
@@ -17,27 +17,27 @@ def create_text_to_image() -> LibraryGraph:
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
"width": ParamIntInvocation(id="width", a=512),
|
||||
"height": ParamIntInvocation(id="height", a=512),
|
||||
"seed": ParamIntInvocation(id="seed", a=-1),
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": TextToLatentsInvocation(id="6"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="a"),
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="a"),
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="a"),
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
@@ -65,9 +65,9 @@ def create_text_to_image() -> LibraryGraph:
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
||||
ExposedNodeInput(node_path="width", field="value", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="value", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
|
||||
],
|
||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ class EventServiceBase:
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
order: int,
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
@@ -46,6 +47,7 @@ class EventServiceBase:
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
order=order,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,26 +3,22 @@
|
||||
import copy
|
||||
import itertools
|
||||
import uuid
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ..invocations import * # noqa: F401 F403
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
@@ -183,15 +179,9 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["iterate_output"] = "iterate_output"
|
||||
|
||||
item: Any = Field(description="The item being iterated over")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"item",
|
||||
]
|
||||
}
|
||||
item: Any = OutputField(
|
||||
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
||||
)
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@@ -200,8 +190,10 @@ class IterateInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
||||
collection: list[Any] = InputField(
|
||||
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
||||
)
|
||||
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||
"""Produces the outputs as values"""
|
||||
@@ -211,15 +203,9 @@ class IterateInvocation(BaseInvocation):
|
||||
class CollectInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["collect_output"] = "collect_output"
|
||||
|
||||
collection: list[Any] = Field(description="The collection of input items")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"collection",
|
||||
]
|
||||
}
|
||||
collection: list[Any] = OutputField(
|
||||
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
||||
)
|
||||
|
||||
|
||||
class CollectInvocation(BaseInvocation):
|
||||
@@ -227,13 +213,14 @@ class CollectInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["collect"] = "collect"
|
||||
|
||||
item: Any = Field(
|
||||
item: Any = InputField(
|
||||
description="The item to collect (all inputs must be of the same type)",
|
||||
default=None,
|
||||
ui_type=UIType.CollectionItem,
|
||||
title="Collection Item",
|
||||
input=Input.Connection,
|
||||
)
|
||||
collection: list[Any] = Field(
|
||||
description="The collection, will be provided on execution",
|
||||
default_factory=list,
|
||||
collection: list[Any] = InputField(
|
||||
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||
@@ -459,7 +446,7 @@ class Graph(BaseModel):
|
||||
node = graph.nodes[node_id]
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) != type(new_node):
|
||||
if type(node) is not type(new_node):
|
||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
@@ -646,7 +633,7 @@ class Graph(BaseModel):
|
||||
[
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
|
||||
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
|
||||
if t != NoneType
|
||||
]
|
||||
) # Get unique types
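A small worked example of the get_origin/get_args expansion used here, with an illustrative field type:

```python
# Illustrative example of the type-expansion logic above.
from typing import Optional, Union, get_args, get_origin

field_type = Optional[Union[int, str]]   # normalizes to Union[int, str, None]
print(get_origin(field_type) is None)    # False -> it's a Union, so expand it
print([t for t in get_args(field_type) if t is not type(None)])  # [<class 'int'>, <class 'str'>]
```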
@@ -937,7 +924,7 @@ class GraphExecutionState(BaseModel):
|
||||
None,
|
||||
)
|
||||
|
||||
if next_node_id == None:
|
||||
if next_node_id is None:
|
||||
return None
|
||||
|
||||
# Get all parents of the next node
|
||||
|
||||
@@ -179,7 +179,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
if image_name not in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
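The TODO above is about refreshing an entry's recency when it is re-inserted; a minimal, self-contained sketch of that idea (not the class's actual code) looks like this:

```python
# Minimal LRU sketch illustrating the "refresh position" TODO above;
# the class name and default size are illustrative, not part of this codebase.
from collections import OrderedDict

class TinyLRU:
    def __init__(self, max_size: int = 10) -> None:
        self._data = OrderedDict()
        self._max_size = max_size

    def set(self, key, value) -> None:
        self._data[key] = value
        self._data.move_to_end(key)          # refresh recency on every set
        if len(self._data) > self._max_size:
            self._data.popitem(last=False)   # evict the least recently used entry
```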
@@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join(
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
],
|
||||
)
|
||||
)
|
||||
@@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
@@ -200,6 +202,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in self._cursor.fetchall()]
|
||||
|
||||
if "starred" not in columns:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -222,6 +234,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -264,7 +282,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
SELECT images.metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
@@ -291,7 +309,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
@@ -302,7 +320,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
@@ -313,7 +331,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
@@ -321,6 +339,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@@ -397,7 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
@@ -500,6 +529,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||
@@ -515,9 +545,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
is_intermediate,
|
||||
starred
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
@@ -529,6 +560,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
starred,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
@@ -289,9 +288,10 @@ class ImageService(ImageServiceABC):
|
||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||
try:
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
metadata = self._services.image_records.get_metadata(image_name)
|
||||
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata()
|
||||
return ImageMetadata(metadata=metadata)
|
||||
|
||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
graph = None
|
||||
@@ -303,7 +303,6 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
||||
metadata = self._services.image_records.get_metadata(image_name)
|
||||
return ImageMetadata(graph=graph, metadata=metadata)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
@@ -379,10 +378,10 @@ class ImageService(ImageServiceABC):
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
self._services.logger.error("Failed to delete image record")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image file")
|
||||
self._services.logger.error("Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
@@ -395,10 +394,10 @@ class ImageService(ImageServiceABC):
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete_many(image_names)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image records")
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image files")
|
||||
self._services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image records and files")
|
||||
@@ -412,10 +411,10 @@ class ImageService(ImageServiceABC):
|
||||
self._services.image_files.delete(image_name)
|
||||
return count
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image records")
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image files")
|
||||
self._services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image records and files")
|
||||
|
||||
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
@@ -32,6 +33,7 @@ class InvocationServices:
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
|
||||
def __init__(
|
||||
@@ -47,6 +49,7 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
):
|
||||
self.board_images = board_images
|
||||
@@ -61,4 +64,5 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
||||
invokeai/app/services/invocation_stats.py (new file, 304 lines)
@@ -0,0 +1,304 @@
|
||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
||||
|
||||
Usage:
|
||||
|
||||
statistics = InvocationStatsService(graph_execution_manager)
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
... execute graphs...
|
||||
statistics.log_stats()
|
||||
|
||||
Typical output:
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||
|
||||
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
# {graph_id => NodeLog}
|
||||
_stats: Dict[str, NodeLog]
|
||||
_cache_stats: Dict[str, CacheStats]
|
||||
ram_used: float
|
||||
ram_changed: float
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
"""
|
||||
Initialize the InvocationStatsService and reset counters to zero
|
||||
:param graph_execution_manager: Graph execution manager for this session
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> AbstractContextManager:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of the invocation. Use it in a with-block placed around the part of the code that executes the invocation.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
"""
|
||||
Reset all statistics for the indicated graph
|
||||
:param graph_execution_state_id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's execution (sec)
:param vram_used: Maximum VRAM used during execution (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log_stats(self):
|
||||
"""
|
||||
Write out the accumulated statistics to the log or somewhere else.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InvocationStatsService(InvocationStatsServiceBase):
|
||||
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
# {graph_id => NodeLog}
|
||||
self._stats: Dict[str, NodeLog] = {}
|
||||
self._cache_stats: Dict[str, CacheStats] = {}
|
||||
self.ram_used: float = 0.0
|
||||
self.ram_changed: float = 0.0
|
||||
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
invocation: BaseInvocation
|
||||
collector: "InvocationStatsServiceBase"
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
collector: "InvocationStatsServiceBase",
|
||||
):
|
||||
"""Initialize statistics for this run."""
|
||||
self.invocation = invocation
|
||||
self.collector = collector
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0.0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.ram_used = psutil.Process().memory_info().rss
|
||||
if self.model_manager:
|
||||
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
|
||||
def __exit__(self, *args):
|
||||
"""Called on exit from the context."""
|
||||
ram_used = psutil.Process().memory_info().rss
|
||||
self.collector.update_mem_stats(
|
||||
ram_used=ram_used / GIG,
|
||||
ram_changed=(ram_used - self.ram_used) / GIG,
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
self._stats = {}
|
||||
|
||||
def reset_stats(self, graph_execution_id: str):
|
||||
try:
|
||||
self._stats.pop(graph_execution_id)
|
||||
except KeyError:
|
||||
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
||||
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
self.ram_used = ram_used
|
||||
self.ram_changed = ram_changed
|
||||
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||
stats = self._stats[graph_id].nodes[invocation_type]
|
||||
stats.calls += 1
|
||||
stats.time_used += time_used
|
||||
stats.max_vram = max(stats.max_vram, vram_used)
|
||||
|
||||
def log_stats(self):
|
||||
completed = set()
|
||||
errored = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
try:
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
except Exception:
|
||||
errored.add(graph_id)
|
||||
continue
|
||||
|
||||
if not current_graph_state.is_complete():
|
||||
continue
|
||||
|
||||
total_time = 0
|
||||
logger.info(f"Graph stats: {graph_id}")
|
||||
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
|
||||
for node_type, stats in self._stats[graph_id].nodes.items():
|
||||
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
|
||||
total_time += stats.time_used
|
||||
|
||||
cache_stats = self._cache_stats[graph_id]
|
||||
hwm = cache_stats.high_watermark / GIG
|
||||
tot = cache_stats.cache_size / GIG
|
||||
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
|
||||
|
||||
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
||||
logger.info(f"RAM used to load models: {loaded:4.2f}G")
|
||||
if torch.cuda.is_available():
|
||||
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
|
||||
logger.info("RAM cache statistics:")
|
||||
logger.info(f" Model cache hits: {cache_stats.hits}")
|
||||
logger.info(f" Model cache misses: {cache_stats.misses}")
|
||||
logger.info(f" Models cached: {cache_stats.in_cache}")
|
||||
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
|
||||
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
|
||||
|
||||
completed.add(graph_id)
|
||||
|
||||
for graph_id in completed:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
|
||||
for graph_id in errored:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
@@ -60,7 +60,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if not name in self.__cache:
|
||||
if name not in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from pydantic import Field
|
||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
|
||||
from invokeai.backend.model_management import (
|
||||
@@ -21,6 +22,7 @@ from invokeai.backend.model_management import (
|
||||
ModelNotFoundException,
|
||||
)
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@@ -193,7 +195,7 @@ class ModelManagerServiceBase(ABC):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@@ -275,6 +277,13 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Attach a CacheStats object to the model cache so that cache statistics can be collected for the current graph.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
@@ -292,7 +301,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: ModuleType,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@@ -321,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `max_cache_size` config setting
|
||||
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
||||
# use new `ram_cache_size` config setting
|
||||
max_cache_size = config.ram_cache_size
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
@@ -396,7 +405,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
@@ -416,7 +425,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
@@ -429,7 +438,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@@ -478,7 +487,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
@@ -499,6 +508,12 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Attach a CacheStats object to the model cache so that cache statistics can be collected for the current graph.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
@@ -573,9 +588,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
@@ -633,8 +648,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
|
||||
invokeai/app/services/models/board_image.py (new file, 8 lines)
@@ -0,0 +1,8 @@
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class BoardImage(BaseModelExcludeNull):
|
||||
board_id: str = Field(description="The id of the board")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from pydantic import Field
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class BoardRecord(BaseModel):
|
||||
class BoardRecord(BaseModelExcludeNull):
|
||||
"""Deserialized board record."""
|
||||
|
||||
board_id: str = Field(description="The unique ID of the board.")
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from pydantic import Extra, Field, StrictBool, StrictStr
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
class ImageRecord(BaseModelExcludeNull):
|
||||
"""Deserialized image record without metadata."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
@@ -38,15 +39,18 @@ class ImageRecord(BaseModel):
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
starred: bool = Field(description="Whether this image is starred.")
|
||||
"""Whether this image is starred."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
"""A set of changes to apply to an image record.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
- `starred`: change whether the image is starred
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
||||
@@ -58,9 +62,11 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
starred: Optional[StrictBool] = Field(default=None, description="The image's new `starred` state")
|
||||
"""The image's new `starred` state."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
@@ -76,11 +82,15 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
|
||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
|
||||
image_record: ImageRecord,
|
||||
image_url: str,
|
||||
thumbnail_url: str,
|
||||
board_id: Optional[str],
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
@@ -108,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
starred = image_dict.get("starred", False)
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
@@ -121,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
starred=starred,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ..models.exceptions import CanceledException
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_stats import InvocationStatsServiceBase
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
@@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
@@ -83,35 +86,43 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
outputs = invocation.invoke(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method;
# this accommodates nodes which require a value, but get it only from a
# connection
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
statistics.log_stats()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
@@ -133,7 +144,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
|
||||
statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
|
||||
@@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
||||
@@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/{image_basename}/full"
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Union
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
@@ -5,7 +6,7 @@ from PIL import Image
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
|
||||
from einops import rearrange
|
||||
from controlnet_aux.util import HWC3, resize_image
|
||||
from controlnet_aux.util import HWC3
|
||||
|
||||
###################################################################
|
||||
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
|
||||
@@ -232,7 +233,8 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
||||
k0 = float(h) / old_h
|
||||
k1 = float(w) / old_w
|
||||
|
||||
safeint = lambda x: int(np.round(x))
|
||||
def safeint(x: Union[int, float]) -> int:
|
||||
return int(np.round(x))
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
|
||||
if resize_mode == "fill_resize": # OUTER_FIT
|
||||
|
||||
@@ -18,5 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
rng = np.random.default_rng(seed=0)
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
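The change from seed=0 to seed=None matters because a fixed seed makes the generator deterministic; a quick illustrative check of the old behaviour:

```python
# Why seed=0 was a bug: two fresh generators seeded with 0 always return the
# same first draw, so get_random_seed() was not random at all.
import numpy as np

SEED_MAX = np.iinfo(np.uint32).max
a = int(np.random.default_rng(seed=0).integers(0, SEED_MAX))
b = int(np.random.default_rng(seed=0).integers(0, SEED_MAX))
assert a == b  # identical every time with a fixed seed
```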
invokeai/app/util/model_exclude_null.py (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
"""
|
||||
We want to exclude null values from objects that make their way to the client.
|
||||
|
||||
Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
|
||||
dict method to do this.
|
||||
|
||||
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
|
||||
"""
|
||||
|
||||
|
||||
class BaseModelExcludeNull(BaseModel):
|
||||
def dict(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Override the default dict method to exclude None values in the response
|
||||
"""
|
||||
kwargs.pop("exclude_none", None)
|
||||
return super().dict(*args, exclude_none=True, **kwargs)
|
||||
|
||||
pass
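A short usage sketch of the behaviour described above; the model and field names below are illustrative, not part of the diff:

```python
# Illustrative usage of BaseModelExcludeNull; ExampleDTO is a made-up model.
from typing import Optional

from pydantic import Field

from invokeai.app.util.model_exclude_null import BaseModelExcludeNull


class ExampleDTO(BaseModelExcludeNull):
    name: str = Field(description="Always present")
    note: Optional[str] = Field(default=None, description="Dropped when None")


print(ExampleDTO(name="img").dict())  # {'name': 'img'} -- the None field is excluded
```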
@@ -4,9 +4,8 @@ from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
@@ -29,6 +28,7 @@ def stable_diffusion_step_callback(
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
@@ -56,23 +56,50 @@ def stable_diffusion_step_callback(
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
# originally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
# originally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
@@ -86,59 +113,6 @@ def stable_diffusion_step_callback(
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
total_steps=node["steps"],
|
||||
)
|
||||
|
||||
|
||||
def stable_diffusion_xl_step_callback(
|
||||
context: InvocationContext,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
# [ 0.0478, 0.1285, 0.0478],
|
||||
# [ 0.1285, 0.2948, 0.1285],
|
||||
# [ 0.0478, 0.1285, 0.0478],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
order=intermediate_state.order,
|
||||
total_steps=intermediate_state.total_steps,
|
||||
)
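For context on the factor matrices above: each 4-channel latent is projected to RGB with a single matrix product to build the low-resolution preview. The helper below is an illustrative sketch only, not the actual sample_to_lowres_estimated_image() implementation (which is not shown in this diff):

```python
# Illustrative sketch of applying a 4x3 latent->RGB factor matrix; estimate_rgb
# is a made-up helper, not InvokeAI's implementation.
import torch

def estimate_rgb(latents: torch.Tensor, rgb_factors: torch.Tensor) -> torch.Tensor:
    """latents: [4, H, W]; rgb_factors: [4, 3] -> uint8 preview of shape [H, W, 3]."""
    rgb = torch.einsum("chw,cr->hwr", latents, rgb_factors)   # project channels to RGB
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)  # normalize to 0..1
    return (rgb * 255).to(torch.uint8)

preview = estimate_rgb(torch.randn(4, 64, 96), torch.randn(4, 3))
print(preview.shape)  # torch.Size([64, 96, 3])
```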