Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-23 02:08:23 -05:00)

Compare commits: v4.2.9.dev...ryan/flux- (825 commits)

.github/workflows/build-container.yml (vendored, 37 changes)
@@ -13,6 +13,12 @@ on:
     tags:
       - 'v*.*.*'
   workflow_dispatch:
+    inputs:
+      push-to-registry:
+        description: Push the built image to the container registry
+        required: false
+        type: boolean
+        default: false
 
 permissions:
   contents: write
@@ -50,16 +56,15 @@ jobs:
           df -h
 
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
 
       - name: Docker meta
         id: meta
-        uses: docker/metadata-action@v4
+        uses: docker/metadata-action@v5
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           images: |
             ghcr.io/${{ github.repository }}
-            ${{ env.DOCKERHUB_REPOSITORY }}
           tags: |
             type=ref,event=branch
             type=ref,event=tag
@@ -72,49 +77,33 @@ jobs:
             suffix=-${{ matrix.gpu-driver }},onlatest=false
 
       - name: Set up QEMU
-        uses: docker/setup-qemu-action@v2
+        uses: docker/setup-qemu-action@v3
 
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
         with:
           platforms: ${{ env.PLATFORMS }}
 
       - name: Login to GitHub Container Registry
         if: github.event_name != 'pull_request'
-        uses: docker/login-action@v2
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}
           password: ${{ secrets.GITHUB_TOKEN }}
 
-      # - name: Login to Docker Hub
-      #   if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
-      #   uses: docker/login-action@v2
-      #   with:
-      #     username: ${{ secrets.DOCKERHUB_USERNAME }}
-      #     password: ${{ secrets.DOCKERHUB_TOKEN }}
-
       - name: Build container
         timeout-minutes: 40
         id: docker_build
-        uses: docker/build-push-action@v4
+        uses: docker/build-push-action@v6
         with:
           context: .
           file: docker/Dockerfile
           platforms: ${{ env.PLATFORMS }}
-          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
+          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: |
            type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
            type=gha,scope=main-${{ matrix.gpu-driver }}
          cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
-
-      # - name: Docker Hub Description
-      #   if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
-      #   uses: peter-evans/dockerhub-description@v3
-      #   with:
-      #     username: ${{ secrets.DOCKERHUB_USERNAME }}
-      #     password: ${{ secrets.DOCKERHUB_TOKEN }}
-      #     repository: ${{ vars.DOCKERHUB_REPOSITORY }}
-      #     short-description: ${{ github.event.repository.description }}

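The added workflow_dispatch block means the container build can now be started by hand, with push-to-registry deciding whether the resulting image is pushed. As a hedged sketch (not part of the change itself), one way to fire that trigger from Python through the GitHub REST API; the token, the `main` ref, and running it against `invoke-ai/InvokeAI` rather than a fork are illustrative assumptions:

```python
import requests

# Manually dispatch the build-container workflow and ask it to push the image.
# <YOUR_GITHUB_TOKEN> and the target ref are placeholders; point at a fork if needed.
resp = requests.post(
    "https://api.github.com/repos/invoke-ai/InvokeAI/actions/workflows/build-container.yml/dispatches",
    headers={
        "Authorization": "Bearer <YOUR_GITHUB_TOKEN>",
        "Accept": "application/vnd.github+json",
    },
    # workflow_dispatch inputs are passed as strings.
    json={"ref": "main", "inputs": {"push-to-registry": "true"}},
)
resp.raise_for_status()  # GitHub returns 204 No Content on success.
```
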
@@ -197,6 +197,22 @@ tips to reduce the problem:
 
 This should be sufficient to generate larger images up to about 1280x1280.
 
+## Checkpoint Models Load Slowly or Use Too Much RAM
+
+The difference between diffusers models (a folder containing multiple
+subfolders) and checkpoint models (a file ending with .safetensors or
+.ckpt) is that InvokeAI is able to load diffusers models into memory
+incrementally, while checkpoint models must be loaded all at
+once. With very large models, or systems with limited RAM, you may
+experience slowdowns and other memory-related issues when loading
+checkpoint models.
+
+To solve this, go to the Model Manager tab (the cube), select the
+checkpoint model that's giving you trouble, and press the "Convert"
+button in the upper right of your browser window. This will convert the
+checkpoint into a diffusers model, after which loading should be
+faster and less memory-intensive.
+
 ## Memory Leak (Linux)
 
 If you notice a memory leak, it could be caused by memory fragmentation as models are loaded and/or moved from CPU to GPU.

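The new troubleshooting entry rests on the difference between the two on-disk formats. A minimal sketch of that difference using the diffusers and safetensors libraries directly (this is not InvokeAI code, and the model paths are placeholders): a diffusers-format folder can be read one submodel at a time, while a checkpoint is a single file whose whole state dict is read at once.

```python
from diffusers import AutoencoderKL, UNet2DConditionModel
from safetensors.torch import load_file

# Diffusers format: a folder of subfolders, so components load incrementally.
vae = AutoencoderKL.from_pretrained("path/to/diffusers-model", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("path/to/diffusers-model", subfolder="unet")

# Checkpoint format: one .safetensors file; the entire state dict is read in one go.
state_dict = load_file("path/to/checkpoint.safetensors")
```
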
@@ -3,8 +3,10 @@
 
 import io
 import pathlib
+import shutil
 import traceback
 from copy import deepcopy
+from enum import Enum
 from tempfile import TemporaryDirectory
 from typing import List, Optional, Type
 
@@ -17,6 +19,7 @@ from starlette.exceptions import HTTPException
 from typing_extensions import Annotated
 
 from invokeai.app.api.dependencies import ApiDependencies
+from invokeai.app.services.config import get_config
 from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
 from invokeai.app.services.model_install.model_install_common import ModelInstallJob
 from invokeai.app.services.model_records import (
@@ -31,6 +34,7 @@ from invokeai.backend.model_manager.config import (
     ModelFormat,
     ModelType,
 )
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
 from invokeai.backend.model_manager.search import ModelSearch
@@ -50,6 +54,13 @@ class ModelsList(BaseModel):
     model_config = ConfigDict(use_enum_values=True)
 
 
+class CacheType(str, Enum):
+    """Cache type - one of vram or ram."""
+
+    RAM = "RAM"
+    VRAM = "VRAM"
+
+
 def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
     """Add a cover image URL to a model configuration."""
     cover_image = dependencies.invoker.services.model_images.get_url(config.key)
@@ -797,3 +808,83 @@ async def get_starter_models() -> list[StarterModel]:
             model.dependencies = missing_deps
 
     return starter_models
+
+
+@model_manager_router.get(
+    "/model_cache",
+    operation_id="get_cache_size",
+    response_model=float,
+    summary="Get maximum size of model manager RAM or VRAM cache.",
+)
+async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
+    """Return the current RAM or VRAM cache size setting (in GB)."""
+    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
+    value = 0.0
+    if cache_type == CacheType.RAM:
+        value = cache.max_cache_size
+    elif cache_type == CacheType.VRAM:
+        value = cache.max_vram_cache_size
+    return value
+
+
+@model_manager_router.put(
+    "/model_cache",
+    operation_id="set_cache_size",
+    response_model=float,
+    summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
+)
+async def set_cache_size(
+    value: float = Query(description="The new value for the maximum cache size"),
+    cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
+    persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
+) -> float:
+    """Set the current RAM or VRAM cache size setting (in GB)."""
+    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
+    app_config = get_config()
+    # Record initial state.
+    vram_old = app_config.vram
+    ram_old = app_config.ram
+
+    # Prepare target state.
+    vram_new = vram_old
+    ram_new = ram_old
+    if cache_type == CacheType.RAM:
+        ram_new = value
+    elif cache_type == CacheType.VRAM:
+        vram_new = value
+    else:
+        raise ValueError(f"Unexpected {cache_type=}.")
+
+    config_path = app_config.config_file_path
+    new_config_path = config_path.with_suffix(".yaml.new")
+
+    try:
+        # Try to apply the target state.
+        cache.max_vram_cache_size = vram_new
+        cache.max_cache_size = ram_new
+        app_config.ram = ram_new
+        app_config.vram = vram_new
+        if persist:
+            app_config.write_file(new_config_path)
+            shutil.move(new_config_path, config_path)
+    except Exception as e:
+        # If there was a failure, restore the initial state.
+        cache.max_cache_size = ram_old
+        cache.max_vram_cache_size = vram_old
+        app_config.ram = ram_old
+        app_config.vram = vram_old
+
+        raise RuntimeError("Failed to update cache size") from e
+    return value
+
+
+@model_manager_router.get(
+    "/stats",
+    operation_id="get_stats",
+    response_model=Optional[CacheStats],
+    summary="Get model manager RAM cache performance statistics.",
+)
+async def get_stats() -> Optional[CacheStats]:
+    """Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
+
+    return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats

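A sketch of how a client might exercise the two new /model_cache routes. The host, port, and the /api/v2/models route prefix are assumptions (the prefix comes from where model_manager_router is created, which is not part of this diff):

```python
import requests

BASE = "http://127.0.0.1:9090/api/v2/models"  # assumed prefix and address

# Read the current RAM cache limit (in GB).
ram_gb = requests.get(f"{BASE}/model_cache", params={"cache_type": "RAM"}).json()
print(f"RAM cache limit: {ram_gb} GB")

# Raise the VRAM cache limit to 4 GB and persist it to invokeai.yaml.
vram_gb = requests.put(
    f"{BASE}/model_cache",
    params={"value": 4.0, "cache_type": "VRAM", "persist": True},
).json()
print(f"VRAM cache limit set to: {vram_gb} GB")
```
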
@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
+    CancelByDestinationResult,
     ClearResult,
     EnqueueBatchResult,
     PruneResult,
@@ -105,6 +106,21 @@ async def cancel_by_batch_ids(
     return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
 
 
+@session_queue_router.put(
+    "/{queue_id}/cancel_by_destination",
+    operation_id="cancel_by_destination",
+    responses={200: {"model": CancelByBatchIDsResult}},
+)
+async def cancel_by_destination(
+    queue_id: str = Path(description="The queue id to perform this operation on"),
+    destination: str = Query(description="The destination to cancel all queue items for"),
+) -> CancelByDestinationResult:
+    """Immediately cancels all queue items with the given origin"""
+    return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
+        queue_id=queue_id, destination=destination
+    )
+
+
 @session_queue_router.put(
     "/{queue_id}/clear",
     operation_id="clear",

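Likewise, a hypothetical call against the new cancel_by_destination route; the /api/v1/queue prefix, the address, and the queue id and destination values are assumptions for illustration:

```python
import requests

BASE = "http://127.0.0.1:9090/api/v1/queue"  # assumed prefix and address

resp = requests.put(
    f"{BASE}/default/cancel_by_destination",
    params={"destination": "my-canvas-session"},  # hypothetical destination value
)
print(resp.json())
```
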
@@ -20,7 +20,6 @@ from typing import (
     Type,
     TypeVar,
     Union,
-    cast,
 )
 
 import semver
@@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
     version: str = Field(
         description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
     )
-    node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
+    node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
     classification: Classification = Field(default=Classification.Stable, description="The node's classification")
 
     model_config = ConfigDict(
@@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
         """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
-        uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
-        if uiconfig is not None:
-            if uiconfig.title is not None:
-                schema["title"] = uiconfig.title
-            if uiconfig.tags is not None:
-                schema["tags"] = uiconfig.tags
-            if uiconfig.category is not None:
-                schema["category"] = uiconfig.category
-            if uiconfig.node_pack is not None:
-                schema["node_pack"] = uiconfig.node_pack
-            schema["classification"] = uiconfig.classification
-            schema["version"] = uiconfig.version
+        if title := model_class.UIConfig.title:
+            schema["title"] = title
+        if tags := model_class.UIConfig.tags:
+            schema["tags"] = tags
+        if category := model_class.UIConfig.category:
+            schema["category"] = category
+        if node_pack := model_class.UIConfig.node_pack:
+            schema["node_pack"] = node_pack
+        schema["classification"] = model_class.UIConfig.classification
+        schema["version"] = model_class.UIConfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
         schema["class"] = "invocation"
@@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
         json_schema_extra={"field_kind": FieldKind.NodeAttribute},
     )
 
-    UIConfig: ClassVar[Type[UIConfigBase]]
+    UIConfig: ClassVar[UIConfigBase]
 
     model_config = ConfigDict(
         protected_namespaces=(),
@@ -441,30 +438,25 @@ def invocation(
         validate_fields(cls.model_fields, invocation_type)
 
         # Add OpenAPI schema extras
-        uiconfig_name = cls.__qualname__ + ".UIConfig"
-        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
-            cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
-        cls.UIConfig.title = title
-        cls.UIConfig.tags = tags
-        cls.UIConfig.category = category
-        cls.UIConfig.classification = classification
+        uiconfig: dict[str, Any] = {}
+        uiconfig["title"] = title
+        uiconfig["tags"] = tags
+        uiconfig["category"] = category
+        uiconfig["classification"] = classification
+        # The node pack is the module name - will be "invokeai" for built-in nodes
+        uiconfig["node_pack"] = cls.__module__.split(".")[0]
 
-        # Grab the node pack's name from the module name, if it's a custom node
-        is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
-        if is_custom_node:
-            cls.UIConfig.node_pack = cls.__module__.split(".")[0]
-        else:
-            cls.UIConfig.node_pack = None
-
         if version is not None:
             try:
                 semver.Version.parse(version)
             except ValueError as e:
                 raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
-            cls.UIConfig.version = version
+            uiconfig["version"] = version
         else:
             logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
-            cls.UIConfig.version = "1.0.0"
+            uiconfig["version"] = "1.0.0"
+        cls.UIConfig = UIConfigBase(**uiconfig)
 
         if use_cache is not None:
             cls.model_fields["use_cache"].default = use_cache

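A quick illustration (not repository code) of the rule the refactor relies on: node_pack is now always the first segment of the module path that defines the invocation, so built-in nodes report "invokeai" and a custom node reports its package name. The custom package name below is made up.

```python
# Built-in node module -> node_pack "invokeai"
assert "invokeai.app.invocations.math".split(".")[0] == "invokeai"

# Hypothetical custom node package -> node_pack "my_node_pack"
assert "my_node_pack.invocations.cool_node".split(".")[0] == "my_node_pack"
```
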
@@ -19,7 +19,8 @@ from invokeai.app.invocations.model import CLIPField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoraPatcher
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -82,9 +83,10 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
-                loras=_lora_loader(),
+            LoraPatcher.apply_lora_patches(
+                model=text_encoder,
+                patches=_lora_loader(),
+                prefix="lora_te_",
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -177,9 +179,9 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
+            LoraPatcher.apply_lora_patches(
                 text_encoder,
-                loras=_lora_loader(),
+                patches=_lora_loader(),
                 prefix=lora_prefix,
                 cached_weights=cached_weights,
             ),

@@ -36,7 +36,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoraPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -185,7 +186,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
     )
     denoise_mask: Optional[DenoiseMaskField] = InputField(
         default=None,
-        description=FieldDescriptions.mask,
+        description=FieldDescriptions.denoise_mask,
         input=Input.Connection,
         ui_order=8,
     )
@@ -979,9 +980,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(
-                unet,
-                loras=_lora_loader(),
+            LoraPatcher.apply_lora_patches(
+                model=unet,
+                patches=_lora_loader(),
+                prefix="lora_unet_",
                 cached_weights=cached_weights,
             ),
         ):

@@ -181,7 +181,7 @@ class FieldDescriptions:
     )
     num_1 = "The first number"
     num_2 = "The second number"
-    mask = "The mask to use for the operation"
+    denoise_mask = "A mask of the region to apply the denoising process to."
     board = "The board to save the image to"
     image = "The image to process"
     tile_size = "Tile size"

invokeai/app/invocations/flux_denoise.py (new file, 267 lines)
from typing import Callable, Iterator, Optional, Tuple

import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
    clip_timestep_schedule,
    generate_img_ids,
    get_noise,
    get_schedule,
    pack,
    unpack,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_denoise",
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run denoising process with a FLUX transformer model."""

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
    )
    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = torch.bfloat16

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)

        # Prepare input noise.
        noise = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        transformer_info = context.models.load(self.transformer.transformer)
        is_schnell = "schnell" in transformer_info.config.config_path

        # Calculate the timestep schedule.
        image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=image_seq_len,
            shift=not is_schnell,
        )

        # Clip the timesteps schedule based on denoising_start and denoising_end.
        timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)

        # Prepare input latent image.
        if init_latents is not None:
            # If init_latents is provided, we are doing image-to-image.

            if is_schnell:
                context.logger.warning(
                    "Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
                    "to be poor. Consider using a FLUX dev model instead."
                )

            # Noise the orig_latents by the appropriate amount for the first timestep.
            t_0 = timesteps[0]
            x = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")

            x = noise

        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
        # denoising steps.
        if len(timesteps) <= 1:
            return x

        inpaint_mask = self._prep_inpaint_mask(context, x)

        b, _c, h, w = x.shape
        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        # Pack all latent tensors.
        init_latents = pack(init_latents) if init_latents is not None else None
        inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
        noise = pack(noise)
        x = pack(x)

        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
        assert image_seq_len == x.shape[1]

        # Prepare inpaint extension.
        inpaint_extension: InpaintExtension | None = None
        if inpaint_mask is not None:
            assert init_latents is not None
            inpaint_extension = InpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        with (
            transformer_info.model_on_device() as (cached_weights, transformer),
            # Apply the LoRA after transformer has been moved to its target device for faster patching.
            LoraPatcher.apply_lora_patches(
                model=transformer,
                patches=self._lora_iterator(context),
                prefix="",
                cached_weights=cached_weights,
            ),
        ):
            assert isinstance(transformer, Flux)

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                inpaint_extension=inpaint_extension,
            )

        x = unpack(x.float(), self.height, self.width)
        return x

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        - Loads the mask
        - Resizes if necessary
        - Casts to same device/dtype as latents
        - Expands mask to the same shape as latents so that they line up after 'packing'

        Args:
            context (InvocationContext): The invocation context, for loading the inpaint mask.
            latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
                device, and dtype for the inpaint mask.

        Returns:
            torch.Tensor | None: Inpaint mask.
        """
        if self.denoise_mask is None:
            return None

        mask = context.tensors.load(self.denoise_mask.mask_name)

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)

        # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
        # `latents`.
        return mask.expand_as(latents)

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, LoRAModelRaw)
            yield (lora_info.model, lora.weight)
            del lora_info

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        def step_callback(state: PipelineIntermediateState) -> None:
            state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
            context.util.flux_step_callback(state)

        return step_callback

invokeai/app/invocations/flux_lora_loader.py (new file, 53 lines)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
    """FLUX LoRA Loader Output"""

    transformer: TransformerField = OutputField(
        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
    )


@invocation(
    "flux_lora_loader",
    title="FLUX LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="FLUX Transformer",
    )

    def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown lora: {lora_key}!")

        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
            raise Exception(f'LoRA "{lora_key}" already applied to transformer.')

        transformer = self.transformer.model_copy(deep=True)
        transformer.loras.append(
            LoRAField(
                lora=self.lora,
                weight=self.weight,
            )
        )

        return FluxLoRALoaderOutput(transformer=transformer)

@@ -1,169 +0,0 @@
import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = self._run_diffusion(context)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = torch.bfloat16

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds

        transformer_info = context.models.load(self.transformer.transformer)

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        x, img_ids = prepare_latent_img_patches(x)

        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=x.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            def step_callback() -> None:
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image before re-enabling
                # latent_image = unpack(img.float(), self.height, self.width)
                # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
                # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]

                # # Create a new tensor of the required shape [255, 255, 3]
                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format

                # # Convert to a NumPy array and then to a PIL Image
                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))

                # (width, height) = image.size
                # width *= 8
                # height *= 8

                # dataURL = image_to_dataURL(image, image_format="JPEG")

                # # TODO: move this whole function to invocation context to properly reference these variables
                # context._services.events.emit_invocation_denoise_progress(
                #     context._data.queue_item,
                #     context._data.invocation,
                #     state,
                #     ProgressImage(dataURL=dataURL, width=width, height=height),
                # )

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil
60  invokeai/app/invocations/flux_vae_decode.py  Normal file
@@ -0,0 +1,60 @@
import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_vae_decode",
    title="FLUX Latents to Image",
    tags=["latents", "image", "vae", "l2i", "flux"],
    category="latents",
    version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents."""

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
        return img_pil

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)
        vae_info = context.models.load(self.vae.vae)
        image = self._vae_decode(vae_info=vae_info, latents=latents)

        TorchDevice.empty_cache()
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
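Editor's note: the decode path above maps the autoencoder's [-1, 1] pixel range to uint8 with (127.5 * (img + 1.0)).byte(). A minimal sketch checking the endpoint behavior (assumes only torch is installed):

# Minimal sketch: the [-1, 1] -> [0, 255] mapping used when converting decoded pixels
# to a PIL image. Endpoints: -1 -> 0, 0 -> 127 (truncated), 1 -> 255.
import torch

img = torch.tensor([-1.0, 0.0, 1.0])
as_bytes = (127.5 * (img + 1.0)).byte()
assert as_bytes.tolist() == [0, 127, 255]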
67  invokeai/app/invocations/flux_vae_encode.py  Normal file
@@ -0,0 +1,67 @@
import einops
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_vae_encode",
    title="FLUX Image to Latents",
    tags=["latents", "image", "vae", "i2l", "flux"],
    category="latents",
    version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
    """Encodes an image into latents."""

    image: ImageField = InputField(
        description="The image to encode.",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Expose seed parameter at the invocation level.
        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
        # should be used for VAE encode sampling.
        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            image_tensor = image_tensor.to(
                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
            )
            latents = vae.encode(image_tensor, sample=True, generator=generator)
            return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        vae_info = context.models.load(self.vae.vae)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
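Editor's note: a minimal sketch of the shape handling above. Before encoding, a single (C, H, W) image tensor gets a leading batch dimension, matching the `c h w -> 1 c h w` rearrange. The random tensor is a stand-in for image_resized_to_grid_as_tensor(...):

# Minimal sketch of the batch-dimension handling before VAE encode.
import einops
import torch

image_tensor = torch.rand(3, 256, 256)  # stand-in for a resized RGB image tensor
if image_tensor.dim() == 3:
    image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
assert image_tensor.shape == (1, 3, 256, 256)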
@@ -6,13 +6,19 @@ import cv2
 import numpy
 from PIL import Image, ImageChops, ImageFilter, ImageOps

-from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    Classification,
+    invocation,
+    invocation_output,
+)
 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import (
     ColorField,
     FieldDescriptions,
     ImageField,
     InputField,
+    OutputField,
     WithBoard,
     WithMetadata,
 )
@@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
         image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)

         return ImageOutput.build(image_dto)
+
+
+@invocation_output("canvas_v2_mask_and_crop_output")
+class CanvasV2MaskAndCropOutput(ImageOutput):
+    offset_x: int = OutputField(description="The x offset of the image, after cropping")
+    offset_y: int = OutputField(description="The y offset of the image, after cropping")
+
+
+@invocation(
+    "canvas_v2_mask_and_crop",
+    title="Canvas V2 Mask and Crop",
+    tags=["image", "mask", "id"],
+    category="image",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Handles Canvas V2 image output masking and cropping"""
+
+    source_image: ImageField | None = InputField(
+        default=None,
+        description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
+    )
+    generated_image: ImageField = InputField(description="The image to apply the mask to")
+    mask: ImageField = InputField(description="The mask to apply")
+    mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
+
+    def _prepare_mask(self, mask: Image.Image) -> Image.Image:
+        mask_array = numpy.array(mask)
+        kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
+        dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
+        dilated_mask = Image.fromarray(dilated_mask_array)
+        if self.mask_blur > 0:
+            mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+        return ImageOps.invert(mask.convert("L"))
+
+    def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
+        mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
+
+        if self.source_image:
+            generated_image = context.images.get_pil(self.generated_image.image_name)
+            source_image = context.images.get_pil(self.source_image.image_name)
+            source_image.paste(generated_image, (0, 0), mask)
+            image_dto = context.images.save(image=source_image)
+        else:
+            generated_image = context.images.get_pil(self.generated_image.image_name)
+            generated_image.putalpha(mask)
+            image_dto = context.images.save(image=generated_image)
+
+        # bbox = image.getbbox()
+        # image = image.crop(bbox)
+
+        return CanvasV2MaskAndCropOutput(
+            image=ImageField(image_name=image_dto.image_name),
+            offset_x=0,
+            offset_y=0,
+            width=image_dto.width,
+            height=image_dto.height,
+        )
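Editor's note: a standalone sketch of the _prepare_mask steps shown above (erode, optional Gaussian blur, invert), with a hypothetical 64x64 solid mask standing in for the canvas mask. It assumes cv2, numpy, and Pillow are installed and is illustrative only:

# Standalone sketch of the mask preparation (assumed inputs; not InvokeAI code).
import cv2
import numpy
from PIL import Image, ImageFilter, ImageOps

mask_blur = 4
mask = Image.new("L", (64, 64), 255)  # stand-in for the canvas mask
mask_array = numpy.array(mask)
kernel = numpy.ones((mask_blur, mask_blur), numpy.uint8)
eroded = Image.fromarray(cv2.erode(mask_array, kernel, iterations=3))  # shrink the masked region
if mask_blur > 0:
    mask = eroded.filter(ImageFilter.GaussianBlur(mask_blur))  # soften the mask edge
prepared = ImageOps.invert(mask.convert("L"))  # invert so the kept region is white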
@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
     title="Tensor Mask to Image",
     tags=["mask"],
     category="mask",
-    version="1.0.0",
+    version="1.1.0",
 )
 class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Convert a mask tensor to an image."""
@@ -135,6 +135,11 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

     def invoke(self, context: InvocationContext) -> ImageOutput:
         mask = context.tensors.load(self.mask.tensor_name)
+
+        # Squeeze the channel dimension if it exists.
+        if mask.dim() == 3:
+            mask = mask.squeeze(0)
+
         # Ensure that the mask is binary.
         if mask.dtype != torch.bool:
             mask = mask > 0.5
@@ -69,6 +69,7 @@ class CLIPField(BaseModel):

 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
+    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


 class T5EncoderField(BaseModel):
@@ -202,7 +203,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         assert isinstance(transformer_config, CheckpointConfigBase)

         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer),
+            transformer=TransformerField(transformer=transformer, loras=[]),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoraPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
 from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
     MultiDiffusionPipeline,
@@ -204,7 +204,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         # Load the UNet model.
         unet_info = context.models.load(self.unet.unet)

-        with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
+        with (
+            ExitStack() as exit_stack,
+            unet_info as unet,
+            LoraPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+        ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
@@ -88,6 +88,8 @@ class QueueItemEventBase(QueueEventBase):

     item_id: int = Field(description="The ID of the queue item")
     batch_id: str = Field(description="The ID of the queue batch")
+    origin: str | None = Field(default=None, description="The origin of the queue item")
+    destination: str | None = Field(default=None, description="The destination of the queue item")


 class InvocationEventBase(QueueItemEventBase):
@@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase):

     session_id: str = Field(description="The ID of the session (aka graph execution state)")
     queue_id: str = Field(description="The ID of the queue")
-    item_id: int = Field(description="The ID of the queue item")
-    batch_id: str = Field(description="The ID of the queue batch")
     session_id: str = Field(description="The ID of the session (aka graph execution state)")
     invocation: AnyInvocation = Field(description="The ID of the invocation")
     invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@@ -114,6 +114,8 @@ class InvocationStartedEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
+            origin=queue_item.origin,
+            destination=queue_item.destination,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -147,6 +149,8 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
+            origin=queue_item.origin,
+            destination=queue_item.destination,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -184,6 +188,8 @@ class InvocationCompleteEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
+            origin=queue_item.origin,
+            destination=queue_item.destination,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -216,6 +222,8 @@ class InvocationErrorEvent(InvocationEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
+            origin=queue_item.origin,
+            destination=queue_item.destination,
             session_id=queue_item.session_id,
             invocation=invocation,
             invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -253,6 +261,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
             queue_id=queue_item.queue_id,
             item_id=queue_item.item_id,
             batch_id=queue_item.batch_id,
+            origin=queue_item.origin,
+            destination=queue_item.destination,
             session_id=queue_item.session_id,
             status=queue_item.status,
             error_type=queue_item.error_type,
@@ -279,12 +289,14 @@ class BatchEnqueuedEvent(QueueEventBase):
         description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
     )
     priority: int = Field(description="The priority of the batch")
+    origin: str | None = Field(default=None, description="The origin of the batch")

     @classmethod
     def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
         return cls(
             queue_id=enqueue_result.queue_id,
             batch_id=enqueue_result.batch.batch_id,
+            origin=enqueue_result.batch.origin,
             enqueued=enqueue_result.enqueued,
             requested=enqueue_result.requested,
             priority=enqueue_result.priority,
@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
         if self.variant:
             base += f":{self.variant or ''}"
         if self.subfolder:
-            base += f":{self.subfolder}"
+            base += f"::{self.subfolder.as_posix()}"
         return base

@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
+    CancelByDestinationResult,
     CancelByQueueIDResult,
     ClearResult,
     EnqueueBatchResult,
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
         """Cancels all queue items with matching batch IDs"""
         pass

+    @abstractmethod
+    def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
+        """Cancels all queue items with the given batch destination"""
+        pass
+
     @abstractmethod
     def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
         """Cancels all queue items with matching queue ID"""
@@ -77,6 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]

 class Batch(BaseModel):
     batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
+    origin: str | None = Field(
+        default=None,
+        description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
+    )
+    destination: str | None = Field(
+        default=None,
+        description="The destination of this queue item. This data is used by the frontend to determine how to handle results.",
+    )
     data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
     graph: Graph = Field(description="The graph to initialize the session with")
     workflow: Optional[WorkflowWithoutID] = Field(
@@ -195,6 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
     status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
     priority: int = Field(default=0, description="The priority of this queue item")
     batch_id: str = Field(description="The ID of the batch associated with this queue item")
+    origin: str | None = Field(
+        default=None,
+        description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
+    )
+    destination: str | None = Field(
+        default=None,
+        description="The destination of this queue item. This data is used by the frontend to determine how to handle results.",
+    )
     session_id: str = Field(
         description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
     )
@@ -294,6 +310,8 @@ class SessionQueueStatus(BaseModel):
 class BatchStatus(BaseModel):
     queue_id: str = Field(..., description="The ID of the queue")
     batch_id: str = Field(..., description="The ID of the batch")
+    origin: str | None = Field(..., description="The origin of the batch")
+    destination: str | None = Field(..., description="The destination of the batch")
     pending: int = Field(..., description="Number of queue items with status 'pending'")
     in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
     completed: int = Field(..., description="Number of queue items with status 'complete'")
@@ -328,6 +346,12 @@ class CancelByBatchIDsResult(BaseModel):
     canceled: int = Field(..., description="Number of queue items canceled")


+class CancelByDestinationResult(CancelByBatchIDsResult):
+    """Result of canceling by a destination"""
+
+    pass
+
+
 class CancelByQueueIDResult(CancelByBatchIDsResult):
     """Result of canceling by queue id"""

@@ -433,6 +457,8 @@ class SessionQueueValueToInsert(NamedTuple):
     field_values: Optional[str]  # field_values json
     priority: int  # priority
     workflow: Optional[str]  # workflow json
+    origin: str | None
+    destination: str | None


 ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@@ -453,6 +479,8 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
                 json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # field_values (json)
                 priority,  # priority
                 json.dumps(workflow, default=to_jsonable_python) if workflow else None,  # workflow (json)
+                batch.origin,  # origin
+                batch.destination,  # destination
             )
         )
     return values_to_insert
@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
+    CancelByDestinationResult,
     CancelByQueueIDResult,
     ClearResult,
     EnqueueBatchResult,
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):

         self.__cursor.executemany(
             """--sql
-            INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
-            VALUES (?, ?, ?, ?, ?, ?, ?)
+            INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
             """,
             values_to_insert,
         )
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
             )
             self.__conn.commit()
             if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
-                batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
-                queue_status = self.get_queue_status(queue_id=queue_id)
-                self.__invoker.services.events.emit_queue_item_status_changed(
-                    current_queue_item, batch_status, queue_status
-                )
+                self._set_queue_item_status(current_queue_item.item_id, "canceled")
         except Exception:
             self.__conn.rollback()
             raise
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
             self.__lock.release()
         return CancelByBatchIDsResult(canceled=count)

+    def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
+        try:
+            current_queue_item = self.get_current(queue_id)
+            self.__lock.acquire()
+            where = """--sql
+                WHERE
+                  queue_id == ?
+                  AND destination == ?
+                  AND status != 'canceled'
+                  AND status != 'completed'
+                  AND status != 'failed'
+                """
+            params = (queue_id, destination)
+            self.__cursor.execute(
+                f"""--sql
+                SELECT COUNT(*)
+                FROM session_queue
+                {where};
+                """,
+                params,
+            )
+            count = self.__cursor.fetchone()[0]
+            self.__cursor.execute(
+                f"""--sql
+                UPDATE session_queue
+                SET status = 'canceled'
+                {where};
+                """,
+                params,
+            )
+            self.__conn.commit()
+            if current_queue_item is not None and current_queue_item.destination == destination:
+                self._set_queue_item_status(current_queue_item.item_id, "canceled")
+        except Exception:
+            self.__conn.rollback()
+            raise
+        finally:
+            self.__lock.release()
+        return CancelByDestinationResult(canceled=count)
+
     def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
         try:
             current_queue_item = self.get_current(queue_id)
@@ -541,7 +578,9 @@ class SqliteSessionQueue(SessionQueueBase):
                 started_at,
                 session_id,
                 batch_id,
-                queue_id
+                queue_id,
+                origin,
+                destination
             FROM session_queue
             WHERE queue_id = ?
             """
@@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
         self.__lock.acquire()
         self.__cursor.execute(
             """--sql
-            SELECT status, count(*)
+            SELECT status, count(*), origin, destination
             FROM session_queue
             WHERE
                 queue_id = ?
@@ -633,6 +672,8 @@ class SqliteSessionQueue(SessionQueueBase):
             result = cast(list[sqlite3.Row], self.__cursor.fetchall())
             total = sum(row[1] for row in result)
             counts: dict[str, int] = {row[0]: row[1] for row in result}
+            origin = result[0]["origin"] if result else None
+            destination = result[0]["destination"] if result else None
         except Exception:
             self.__conn.rollback()
             raise
@@ -641,6 +682,8 @@ class SqliteSessionQueue(SessionQueueBase):

         return BatchStatus(
             batch_id=batch_id,
+            origin=origin,
+            destination=destination,
             queue_id=queue_id,
             pending=counts.get("pending", 0),
             in_progress=counts.get("in_progress", 0),
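Editor's note: a runnable sketch of the SQL semantics behind cancel_by_destination, run against a throwaway in-memory table. The table below is a simplified stand-in for the real session_queue schema; the WHERE clause mirrors the one added above, so only pending/in-progress rows with the matching destination are canceled.

# Runnable sketch of the cancel-by-destination SQL against a minimal stand-in table.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE session_queue (item_id INTEGER PRIMARY KEY, queue_id TEXT, destination TEXT, status TEXT);")
rows = [
    ("default", "canvas_v2", "pending"),
    ("default", "canvas_v2", "in_progress"),
    ("default", "canvas_v2", "completed"),  # already finished -> left alone
    ("default", None, "pending"),           # different destination -> left alone
]
conn.executemany("INSERT INTO session_queue (queue_id, destination, status) VALUES (?, ?, ?);", rows)

where = (
    "WHERE queue_id == ? AND destination == ? "
    "AND status != 'canceled' AND status != 'completed' AND status != 'failed'"
)
params = ("default", "canvas_v2")
count = conn.execute(f"SELECT COUNT(*) FROM session_queue {where};", params).fetchone()[0]
conn.execute(f"UPDATE session_queue SET status = 'canceled' {where};", params)
assert count == 2  # only the pending and in_progress canvas_v2 items are counted and canceled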
@@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
 from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
-from invokeai.app.util.step_callback import stable_diffusion_step_callback
+from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
 from invokeai.backend.model_manager.config import (
     AnyModel,
     AnyModelConfig,
@@ -557,6 +557,24 @@ class UtilInterface(InvocationContextInterface):
             is_canceled=self.is_canceled,
         )

+    def flux_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
+        """
+        The step callback emits a progress event with the current step, the total number of
+        steps, a preview image, and some other internal metadata.
+
+        This should be called after each denoising step.
+
+        Args:
+            intermediate_state: The intermediate state of the diffusion pipeline.
+        """
+
+        flux_step_callback(
+            context_data=self._data,
+            intermediate_state=intermediate_state,
+            events=self._services.events,
+            is_canceled=self.is_canceled,
+        )
+

 class InvocationContext:
     """Provides access to various services and data for the current invocation.
@@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
     migrator.register_migration(build_migration_12(app_config=config))
     migrator.register_migration(build_migration_13())
     migrator.register_migration(build_migration_14())
+    migrator.register_migration(build_migration_15())
     migrator.run_migrations()

     return db
@@ -0,0 +1,34 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration15Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._add_origin_col(cursor)

    def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
        """
        - Adds `origin` column to the session queue table.
        - Adds `destination` column to the session queue table.
        """

        cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
        cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")


def build_migration_15() -> Migration:
    """
    Build the migration from database version 14 to 15.

    This migration does the following:
    - Adds `origin` column to the session queue table.
    - Adds `destination` column to the session queue table.
    """
    migration_15 = Migration(
        from_version=14,
        to_version=15,
        callback=Migration15Callback(),
    )

    return migration_15
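Editor's note: a runnable sketch of what Migration15Callback does, applied to a throwaway in-memory database. The single-column table is a minimal stand-in for the real session_queue schema; the PRAGMA check just confirms the two new columns exist afterwards.

# Runnable sketch of the ALTER TABLE statements issued by migration 15.
import sqlite3

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("CREATE TABLE session_queue (item_id INTEGER PRIMARY KEY);")  # stand-in table
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
columns = [row[1] for row in cursor.execute("PRAGMA table_info(session_queue);")]
assert "origin" in columns and "destination" in columns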
@@ -0,0 +1,407 @@
{
  "name": "FLUX Image to Image",
  "author": "InvokeAI",
  "description": "A simple image-to-image workflow using a FLUX dev model.",
  "version": "1.0.4",
  "contact": "",
  "tags": "image2image, flux, image-to-image",
  "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
  "exposedFields": [
    { "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "model" },
    { "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "t5_encoder_model" },
    { "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "clip_embed_model" },
    { "nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "fieldName": "vae_model" },
    { "nodeId": "ace0258f-67d7-4eee-a218-6fff27065214", "fieldName": "denoising_start" },
    { "nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "fieldName": "prompt" },
    { "nodeId": "ace0258f-67d7-4eee-a218-6fff27065214", "fieldName": "num_steps" }
  ],
  "meta": { "version": "3.0.0", "category": "default" },
  "nodes": [
    {
      "id": "2981a67c-480f-4237-9384-26b68dbf912b",
      "type": "invocation",
      "data": {
        "id": "2981a67c-480f-4237-9384-26b68dbf912b",
        "type": "flux_vae_encode",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "inputs": {
          "image": { "name": "image", "label": "", "value": { "image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png" } },
          "vae": { "name": "vae", "label": "" }
        }
      },
      "position": { "x": 732.7680166609682, "y": -24.37398171806909 }
    },
    {
      "id": "ace0258f-67d7-4eee-a218-6fff27065214",
      "type": "invocation",
      "data": {
        "id": "ace0258f-67d7-4eee-a218-6fff27065214",
        "type": "flux_denoise",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "inputs": {
          "board": { "name": "board", "label": "" },
          "metadata": { "name": "metadata", "label": "" },
          "latents": { "name": "latents", "label": "" },
          "denoise_mask": { "name": "denoise_mask", "label": "" },
          "denoising_start": { "name": "denoising_start", "label": "", "value": 0.04 },
          "denoising_end": { "name": "denoising_end", "label": "", "value": 1 },
          "transformer": { "name": "transformer", "label": "" },
          "positive_text_conditioning": { "name": "positive_text_conditioning", "label": "" },
          "width": { "name": "width", "label": "", "value": 1024 },
          "height": { "name": "height", "label": "", "value": 1024 },
          "num_steps": { "name": "num_steps", "label": "Steps (Recommend 30 for Dev, 4 for Schnell)", "value": 30 },
          "guidance": { "name": "guidance", "label": "", "value": 4 },
          "seed": { "name": "seed", "label": "", "value": 0 }
        }
      },
      "position": { "x": 1182.8836633018684, "y": -251.38882958913183 }
    },
    {
      "id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
      "type": "invocation",
      "data": {
        "id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
        "type": "flux_vae_decode",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": false,
        "useCache": true,
        "inputs": {
          "board": { "name": "board", "label": "" },
          "metadata": { "name": "metadata", "label": "" },
          "latents": { "name": "latents", "label": "" },
          "vae": { "name": "vae", "label": "" }
        }
      },
      "position": { "x": 1575.5797431839133, "y": -209.00150975507415 }
    },
    {
      "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
      "type": "invocation",
      "data": {
        "id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
        "type": "flux_model_loader",
        "version": "1.0.4",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "inputs": {
          "model": { "name": "model", "label": "Model (dev variant recommended for Image-to-Image)" },
          "t5_encoder_model": { "name": "t5_encoder_model", "label": "" },
          "clip_embed_model": {
            "name": "clip_embed_model",
            "label": "",
            "value": {
              "key": "fa23a584-b623-415d-832a-21b5098ff1a1",
              "hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
              "name": "clip-vit-large-patch14",
              "base": "any",
              "type": "clip_embed"
            }
          },
          "vae_model": {
            "name": "vae_model",
            "label": "",
            "value": {
              "key": "74fc82ba-c0a8-479d-a890-2126f82da758",
              "hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
              "name": "FLUX.1-schnell_ae",
              "base": "flux",
              "type": "vae"
            }
          }
        }
      },
      "position": { "x": 328.1809894659957, "y": -90.2241133566946 }
    },
    {
      "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "type": "invocation",
      "data": {
        "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
        "type": "flux_text_encoder",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "inputs": {
          "clip": { "name": "clip", "label": "" },
          "t5_encoder": { "name": "t5_encoder", "label": "" },
          "t5_max_seq_len": { "name": "t5_max_seq_len", "label": "T5 Max Seq Len", "value": 256 },
          "prompt": { "name": "prompt", "label": "", "value": "a cat wearing a birthday hat" }
        }
      },
      "position": { "x": 745.8823365057267, "y": -299.60249175851914 }
    },
    {
      "id": "4754c534-a5f3-4ad0-9382-7887985e668c",
      "type": "invocation",
      "data": {
        "id": "4754c534-a5f3-4ad0-9382-7887985e668c",
        "type": "rand_int",
        "version": "1.0.1",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "inputs": {
          "low": { "name": "low", "label": "", "value": 0 },
          "high": { "name": "high", "label": "", "value": 2147483647 }
        }
      },
      "position": { "x": 725.834098928012, "y": 496.2710031089931 }
    }
  ],
  "edges": [
    { "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height", "type": "default", "source": "2981a67c-480f-4237-9384-26b68dbf912b", "target": "ace0258f-67d7-4eee-a218-6fff27065214", "sourceHandle": "height", "targetHandle": "height" },
    { "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width", "type": "default", "source": "2981a67c-480f-4237-9384-26b68dbf912b", "target": "ace0258f-67d7-4eee-a218-6fff27065214", "sourceHandle": "width", "targetHandle": "width" },
    { "id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents", "type": "default", "source": "2981a67c-480f-4237-9384-26b68dbf912b", "target": "ace0258f-67d7-4eee-a218-6fff27065214", "sourceHandle": "latents", "targetHandle": "latents" },
    { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "target": "2981a67c-480f-4237-9384-26b68dbf912b", "sourceHandle": "vae", "targetHandle": "vae" },
    { "id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents", "type": "default", "source": "ace0258f-67d7-4eee-a218-6fff27065214", "target": "7e5172eb-48c1-44db-a770-8fd83e1435d1", "sourceHandle": "latents", "targetHandle": "latents" },
    { "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed", "type": "default", "source": "4754c534-a5f3-4ad0-9382-7887985e668c", "target": "ace0258f-67d7-4eee-a218-6fff27065214", "sourceHandle": "value", "targetHandle": "seed" },
    { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "target": "ace0258f-67d7-4eee-a218-6fff27065214", "sourceHandle": "transformer", "targetHandle": "transformer" },
    { "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning", "type": "default", "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "target": "ace0258f-67d7-4eee-a218-6fff27065214", "sourceHandle": "conditioning", "targetHandle": "positive_text_conditioning" },
    { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "target": "7e5172eb-48c1-44db-a770-8fd83e1435d1", "sourceHandle": "vae", "targetHandle": "vae" },
    { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "sourceHandle": "max_seq_len", "targetHandle": "t5_max_seq_len" },
    { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "sourceHandle": "t5_encoder", "targetHandle": "t5_encoder" },
    { "id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip", "type": "default", "source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90", "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "sourceHandle": "clip", "targetHandle": "clip" }
  ]
}
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"name": "FLUX Text to Image",
|
"name": "FLUX Text to Image",
|
||||||
"author": "InvokeAI",
|
"author": "InvokeAI",
|
||||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
|
||||||
"version": "1.0.4",
|
"version": "1.0.4",
|
||||||
"contact": "",
|
"contact": "",
|
||||||
"tags": "text2image, flux",
|
"tags": "text2image, flux",
|
||||||
@@ -11,17 +11,25 @@
|
|||||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
"fieldName": "model"
|
"fieldName": "model"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
|
"fieldName": "t5_encoder_model"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
|
"fieldName": "clip_embed_model"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
|
"fieldName": "vae_model"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||||
"fieldName": "prompt"
|
"fieldName": "prompt"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
"fieldName": "num_steps"
|
"fieldName": "num_steps"
|
||||||
},
|
|
||||||
{
|
|
||||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
|
||||||
"fieldName": "t5_encoder_model"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"meta": {
|
"meta": {
|
||||||
@@ -29,6 +37,121 @@
|
|||||||
"category": "default"
|
"category": "default"
|
||||||
},
|
},
|
||||||
"nodes": [
|
"nodes": [
|
||||||
|
{
|
||||||
|
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
|
"type": "invocation",
|
||||||
|
"data": {
|
||||||
|
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
|
"type": "flux_denoise",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"label": "",
|
||||||
|
"notes": "",
|
||||||
|
"isOpen": true,
|
||||||
|
"isIntermediate": true,
|
||||||
|
"useCache": true,
|
||||||
|
"inputs": {
|
||||||
|
"board": {
|
||||||
|
"name": "board",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"name": "metadata",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"latents": {
|
||||||
|
"name": "latents",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"denoise_mask": {
|
||||||
|
"name": "denoise_mask",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"denoising_start": {
|
||||||
|
"name": "denoising_start",
|
||||||
|
"label": "",
|
||||||
|
"value": 0
|
||||||
|
},
|
||||||
|
"denoising_end": {
|
||||||
|
"name": "denoising_end",
|
||||||
|
"label": "",
|
||||||
|
"value": 1
|
||||||
|
},
|
||||||
|
"transformer": {
|
||||||
|
"name": "transformer",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"positive_text_conditioning": {
|
||||||
|
"name": "positive_text_conditioning",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"width": {
|
||||||
|
"name": "width",
|
||||||
|
"label": "",
|
||||||
|
"value": 1024
|
||||||
|
},
|
||||||
|
"height": {
|
||||||
|
"name": "height",
|
||||||
|
"label": "",
|
||||||
|
"value": 1024
|
||||||
|
},
|
||||||
|
"num_steps": {
|
||||||
|
"name": "num_steps",
|
||||||
|
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||||
|
"value": 30
|
||||||
|
},
|
||||||
|
"guidance": {
|
||||||
|
"name": "guidance",
|
||||||
|
"label": "",
|
||||||
|
"value": 4
|
||||||
|
},
|
||||||
|
"seed": {
|
||||||
|
"name": "seed",
|
||||||
|
"label": "",
|
||||||
|
"value": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"x": 1186.1868226120378,
|
||||||
|
"y": -214.9459927686657
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||||
|
"type": "invocation",
|
||||||
|
"data": {
|
||||||
|
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||||
|
"type": "flux_vae_decode",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"label": "",
|
||||||
|
"notes": "",
|
||||||
|
"isOpen": true,
|
||||||
|
"isIntermediate": false,
|
||||||
|
"useCache": true,
|
||||||
|
"inputs": {
|
||||||
|
"board": {
|
||||||
|
"name": "board",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"name": "metadata",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"latents": {
|
||||||
|
"name": "latents",
|
||||||
|
"label": ""
|
||||||
|
},
|
||||||
|
"vae": {
|
||||||
|
"name": "vae",
|
||||||
|
"label": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"x": 1575.5797431839133,
|
||||||
|
"y": -209.00150975507415
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
"type": "invocation",
|
"type": "invocation",
|
||||||
@@ -99,8 +222,8 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"position": {
|
"position": {
|
||||||
"x": 824.1970602278849,
|
"x": 778.4899149328337,
|
||||||
"y": 146.98251001061735
|
"y": -100.36469216659502
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -129,77 +252,52 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"position": {
|
"position": {
|
||||||
"x": 822.9899179655476,
|
"x": 800.9667463219505,
|
||||||
"y": 360.9657214885052
|
"y": 285.8297267547506
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
|
||||||
"type": "invocation",
|
|
||||||
"data": {
|
|
||||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
|
||||||
"type": "flux_text_to_image",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"label": "",
|
|
||||||
"notes": "",
|
|
||||||
"isOpen": true,
|
|
||||||
"isIntermediate": false,
|
|
||||||
"useCache": true,
|
|
||||||
"inputs": {
|
|
||||||
"board": {
|
|
||||||
"name": "board",
|
|
||||||
"label": ""
|
|
||||||
},
|
|
||||||
"metadata": {
|
|
||||||
"name": "metadata",
|
|
||||||
"label": ""
|
|
||||||
},
|
|
||||||
"transformer": {
|
|
||||||
"name": "transformer",
|
|
||||||
"label": ""
|
|
||||||
},
|
|
||||||
"vae": {
|
|
||||||
"name": "vae",
|
|
||||||
"label": ""
|
|
||||||
},
|
|
||||||
"positive_text_conditioning": {
|
|
||||||
"name": "positive_text_conditioning",
|
|
||||||
"label": ""
|
|
||||||
},
|
|
||||||
"width": {
|
|
||||||
"name": "width",
|
|
||||||
"label": "",
|
|
||||||
"value": 1024
|
|
||||||
},
|
|
||||||
"height": {
|
|
||||||
"name": "height",
|
|
||||||
"label": "",
|
|
||||||
"value": 1024
|
|
||||||
},
|
|
||||||
"num_steps": {
|
|
||||||
"name": "num_steps",
|
|
||||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
|
||||||
"value": 30
|
|
||||||
},
|
|
||||||
"guidance": {
|
|
||||||
"name": "guidance",
|
|
||||||
"label": "",
|
|
||||||
"value": 4
|
|
||||||
},
|
|
||||||
"seed": {
|
|
||||||
"name": "seed",
|
|
||||||
"label": "",
|
|
||||||
"value": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"position": {
|
|
||||||
"x": 1216.3900791301849,
|
|
||||||
"y": 5.500841807102248
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"edges": [
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
|
||||||
|
"type": "default",
|
||||||
|
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
|
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
|
"sourceHandle": "transformer",
|
||||||
|
"targetHandle": "transformer"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
|
||||||
|
"type": "default",
|
||||||
|
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||||
|
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
|
"sourceHandle": "conditioning",
|
||||||
|
"targetHandle": "positive_text_conditioning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
|
||||||
|
"type": "default",
|
||||||
|
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||||
|
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
|
"sourceHandle": "value",
|
||||||
|
"targetHandle": "seed"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||||
|
"type": "default",
|
||||||
|
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||||
|
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||||
|
"sourceHandle": "latents",
|
||||||
|
"targetHandle": "latents"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||||
|
"type": "default",
|
||||||
|
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||||
|
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||||
|
"sourceHandle": "vae",
|
||||||
|
"targetHandle": "vae"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||||
"type": "default",
|
"type": "default",
|
||||||
@@ -208,14 +306,6 @@
|
|||||||
"sourceHandle": "max_seq_len",
|
"sourceHandle": "max_seq_len",
|
||||||
"targetHandle": "t5_max_seq_len"
|
"targetHandle": "t5_max_seq_len"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
|
||||||
"type": "default",
|
|
||||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
|
||||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
|
||||||
"sourceHandle": "vae",
|
|
||||||
"targetHandle": "vae"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||||
"type": "default",
|
"type": "default",
|
||||||
@@ -231,30 +321,6 @@
|
|||||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||||
"sourceHandle": "clip",
|
"sourceHandle": "clip",
|
||||||
"targetHandle": "clip"
|
"targetHandle": "clip"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
|
||||||
"type": "default",
|
|
||||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
|
||||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
|
||||||
"sourceHandle": "transformer",
|
|
||||||
"targetHandle": "transformer"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
|
||||||
"type": "default",
|
|
||||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
|
||||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
|
||||||
"sourceHandle": "conditioning",
|
|
||||||
"targetHandle": "positive_text_conditioning"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
|
|
||||||
"type": "default",
|
|
||||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
|
||||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
|
||||||
"sourceHandle": "value",
|
|
||||||
"targetHandle": "seed"
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -38,6 +38,25 @@ SD1_5_LATENT_RGB_FACTORS = [
     [-0.1307, -0.1874, -0.7445], # L4
 ]
+
+FLUX_LATENT_RGB_FACTORS = [
+    [-0.0412, 0.0149, 0.0521],
+    [0.0056, 0.0291, 0.0768],
+    [0.0342, -0.0681, -0.0427],
+    [-0.0258, 0.0092, 0.0463],
+    [0.0863, 0.0784, 0.0547],
+    [-0.0017, 0.0402, 0.0158],
+    [0.0501, 0.1058, 0.1152],
+    [-0.0209, -0.0218, -0.0329],
+    [-0.0314, 0.0083, 0.0896],
+    [0.0851, 0.0665, -0.0472],
+    [-0.0534, 0.0238, -0.0024],
+    [0.0452, -0.0026, 0.0048],
+    [0.0892, 0.0831, 0.0881],
+    [-0.1117, -0.0304, -0.0789],
+    [0.0027, -0.0479, -0.0043],
+    [-0.1146, -0.0827, -0.0598],
+]


 def sample_to_lowres_estimated_image(
     samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
@@ -94,3 +113,32 @@ def stable_diffusion_step_callback(
         intermediate_state,
         ProgressImage(dataURL=dataURL, width=width, height=height),
     )
+
+
+def flux_step_callback(
+    context_data: "InvocationContextData",
+    intermediate_state: PipelineIntermediateState,
+    events: "EventServiceBase",
+    is_canceled: Callable[[], bool],
+) -> None:
+    if is_canceled():
+        raise CanceledException
+    sample = intermediate_state.latents
+    latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
+    latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
+    latent_image = latent_image_perm @ latent_rgb_factors
+    latents_ubyte = (
+        ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF)  # change scale from -1..1 to 0..1 # to 0..255
+    ).to(device="cpu", dtype=torch.uint8)
+    image = Image.fromarray(latents_ubyte.cpu().numpy())
+    (width, height) = image.size
+    width *= 8
+    height *= 8
+    dataURL = image_to_dataURL(image, image_format="JPEG")
+
+    events.emit_invocation_denoise_progress(
+        context_data.queue_item,
+        context_data.invocation,
+        intermediate_state,
+        ProgressImage(dataURL=dataURL, width=width, height=height),
+    )
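The new FLUX_LATENT_RGB_FACTORS table drives the same low-resolution preview mechanism as the SD factors above: each 16-channel latent pixel is projected to RGB with a fixed 16x3 matrix, rescaled from [-1, 1] to 8-bit, and reported at 8x the latent resolution. A small shape sketch, illustrative only and not part of the diff (it assumes FLUX_LATENT_RGB_FACTORS from the hunk above is in scope and a 1024x1024 image, i.e. a 128x128 latent):

import torch

factors = torch.tensor(FLUX_LATENT_RGB_FACTORS)   # (16, 3) projection matrix
latents = torch.randn(16, 128, 128)                # (C, H, W) latents for a hypothetical 1024x1024 image
rgb = latents.permute(1, 2, 0) @ factors           # (H, W, 3)
preview = ((rgb + 1) / 2).clamp(0, 1).mul(0xFF).to(torch.uint8)
assert preview.shape == (128, 128, 3)              # width/height are then multiplied by 8 for display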
56 invokeai/backend/flux/denoise.py Normal file
@@ -0,0 +1,56 @@
from typing import Callable

import torch
from tqdm import tqdm

from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState


def denoise(
    model: Flux,
    # model input
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[PipelineIntermediateState], None],
    guidance: float,
    inpaint_extension: InpaintExtension | None,
):
    step = 0
    # guidance_vec is ignored for schnell.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        preview_img = img - t_curr * pred
        img = img + (t_prev - t_curr) * pred

        if inpaint_extension is not None:
            img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)

        step_callback(
            PipelineIntermediateState(
                step=step,
                order=1,
                total_steps=len(timesteps),
                timestep=int(t_curr),
                latents=preview_img,
            ),
        )
        step += 1

    return img
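For readers skimming the math: denoise() integrates a rectified-flow ODE with plain Euler steps, using the convention seen elsewhere in this branch that t=1 is pure noise and t=0 is the clean image, so the model output approximates the constant velocity noise - x0 along the straight path. A standalone sanity check, illustrative only (all tensors are made up):

import torch

x0 = torch.randn(4)     # hypothetical clean latents
noise = torch.randn(4)  # hypothetical noise sample
v = noise - x0          # true velocity of the straight path x_t = (1 - t) * x0 + t * noise

t_curr, t_prev = 0.8, 0.6
x_t = (1 - t_curr) * x0 + t_curr * noise

# `preview_img = img - t_curr * pred` recovers the clean image when pred equals the true velocity.
assert torch.allclose(x_t - t_curr * v, x0, atol=1e-6)
# `img = img + (t_prev - t_curr) * pred` lands exactly on the path at t_prev.
assert torch.allclose(x_t + (t_prev - t_curr) * v, (1 - t_prev) * x0 + t_prev * noise, atol=1e-6)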
35 invokeai/backend/flux/inpaint_extension.py Normal file
@@ -0,0 +1,35 @@
import torch


class InpaintExtension:
    """A class for managing inpainting with FLUX."""

    def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
        """Initialize InpaintExtension.

        Args:
            init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
            inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
                re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
                inpainted region with the background. In 'packed' format.
            noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
        """
        assert init_latents.shape == inpaint_mask.shape == noise.shape
        self._init_latents = init_latents
        self._inpaint_mask = inpaint_mask
        self._noise = noise

    def merge_intermediate_latents_with_init_latents(
        self, intermediate_latents: torch.Tensor, timestep: float
    ) -> torch.Tensor:
        """Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
        update the intermediate latents to keep the regions that are not being inpainted on the correct noise
        trajectory.

        This function should be called after each denoising step.
        """
        # Noise the init latents for the current timestep.
        noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents

        # Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
        return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
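A small standalone check of the merge rule above, illustrative only (the "packed" latents are stood in for by arbitrary flat tensors): regions where inpaint_mask is 0 are forced back onto the noised init-latent trajectory after each step, while masked-in regions keep the denoiser's output.

import torch

from invokeai.backend.flux.inpaint_extension import InpaintExtension

init_latents = torch.zeros(1, 4, 8)   # pretend "packed" clean latents
noise = torch.randn(1, 4, 8)
inpaint_mask = torch.zeros(1, 4, 8)
inpaint_mask[:, :, :4] = 1.0          # re-generate only the first half

ext = InpaintExtension(init_latents=init_latents, inpaint_mask=inpaint_mask, noise=noise)

intermediate = torch.randn(1, 4, 8)   # stand-in for the denoiser output at t_prev
t_prev = 0.5
merged = ext.merge_intermediate_latents_with_init_latents(intermediate, t_prev)

expected_background = noise * t_prev + (1.0 - t_prev) * init_latents
assert torch.allclose(merged[:, :, 4:], expected_background[:, :, 4:])  # unmasked region follows init trajectory
assert torch.allclose(merged[:, :, :4], intermediate[:, :, :4])         # masked region keeps denoiser output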
@@ -258,16 +258,17 @@ class Decoder(nn.Module):


 class DiagonalGaussian(nn.Module):
-    def __init__(self, sample: bool = True, chunk_dim: int = 1):
+    def __init__(self, chunk_dim: int = 1):
         super().__init__()
-        self.sample = sample
         self.chunk_dim = chunk_dim

-    def forward(self, z: Tensor) -> Tensor:
+    def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
         mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
-        if self.sample:
+        if sample:
             std = torch.exp(0.5 * logvar)
-            return mean + std * torch.randn_like(mean)
+            # Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
+            # have to use torch.randn(...) instead.
+            return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
         else:
             return mean

@@ -297,8 +298,21 @@ class AutoEncoder(nn.Module):
         self.scale_factor = params.scale_factor
         self.shift_factor = params.shift_factor

-    def encode(self, x: Tensor) -> Tensor:
-        z = self.reg(self.encoder(x))
+    def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
+        """Run VAE encoding on input tensor x.
+
+        Args:
+            x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
+            sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
+                Defaults to True.
+            generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
+                Defaults to None.
+
+        Returns:
+            Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
+        """
+
+        z = self.reg(self.encoder(x), sample=sample, generator=generator)
         z = self.scale_factor * (z - self.shift_factor)
         return z
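The encode() signature change threads an optional torch.Generator through DiagonalGaussian, so VAE sampling can be made deterministic per call instead of being toggled at construction time. A hedged usage sketch, not part of the diff (ae and image_tensor are placeholder names for an AutoEncoder instance and a pre-scaled (B, 3, H, W) input on the same device):

import torch

generator = torch.Generator(device=image_tensor.device).manual_seed(42)
latents_a = ae.encode(image_tensor, sample=True, generator=generator)

generator = torch.Generator(device=image_tensor.device).manual_seed(42)
latents_b = ae.encode(image_tensor, sample=True, generator=generator)

assert torch.equal(latents_a, latents_b)               # same seed -> same sampled latents
latents_mean = ae.encode(image_tensor, sample=False)   # skip sampling, return the distribution mean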
@@ -1,167 +0,0 @@
|
|||||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torch import Tensor
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from invokeai.backend.flux.model import Flux
|
|
||||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
|
||||||
|
|
||||||
|
|
||||||
def get_noise(
|
|
||||||
num_samples: int,
|
|
||||||
height: int,
|
|
||||||
width: int,
|
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
seed: int,
|
|
||||||
):
|
|
||||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
|
||||||
rand_device = "cpu"
|
|
||||||
rand_dtype = torch.float16
|
|
||||||
return torch.randn(
|
|
||||||
num_samples,
|
|
||||||
16,
|
|
||||||
# allow for packing
|
|
||||||
2 * math.ceil(height / 16),
|
|
||||||
2 * math.ceil(width / 16),
|
|
||||||
device=rand_device,
|
|
||||||
dtype=rand_dtype,
|
|
||||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
|
||||||
).to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
|
||||||
bs, c, h, w = img.shape
|
|
||||||
if bs == 1 and not isinstance(prompt, str):
|
|
||||||
bs = len(prompt)
|
|
||||||
|
|
||||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
if img.shape[0] == 1 and bs > 1:
|
|
||||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
|
||||||
|
|
||||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
if isinstance(prompt, str):
|
|
||||||
prompt = [prompt]
|
|
||||||
txt = t5(prompt)
|
|
||||||
if txt.shape[0] == 1 and bs > 1:
|
|
||||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
|
||||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
|
||||||
|
|
||||||
vec = clip(prompt)
|
|
||||||
if vec.shape[0] == 1 and bs > 1:
|
|
||||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"img": img,
|
|
||||||
"img_ids": img_ids.to(img.device),
|
|
||||||
"txt": txt.to(img.device),
|
|
||||||
"txt_ids": txt_ids.to(img.device),
|
|
||||||
"vec": vec.to(img.device),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
|
||||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
|
||||||
|
|
||||||
|
|
||||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
|
||||||
m = (y2 - y1) / (x2 - x1)
|
|
||||||
b = y1 - m * x1
|
|
||||||
return lambda x: m * x + b
|
|
||||||
|
|
||||||
|
|
||||||
def get_schedule(
|
|
||||||
num_steps: int,
|
|
||||||
image_seq_len: int,
|
|
||||||
base_shift: float = 0.5,
|
|
||||||
max_shift: float = 1.15,
|
|
||||||
shift: bool = True,
|
|
||||||
) -> list[float]:
|
|
||||||
# extra step for zero
|
|
||||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
|
||||||
|
|
||||||
# shifting the schedule to favor high timesteps for higher signal images
|
|
||||||
if shift:
|
|
||||||
# estimate mu based on linear estimation between two points
|
|
||||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
|
||||||
timesteps = time_shift(mu, 1.0, timesteps)
|
|
||||||
|
|
||||||
return timesteps.tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def denoise(
|
|
||||||
model: Flux,
|
|
||||||
# model input
|
|
||||||
img: Tensor,
|
|
||||||
img_ids: Tensor,
|
|
||||||
txt: Tensor,
|
|
||||||
txt_ids: Tensor,
|
|
||||||
vec: Tensor,
|
|
||||||
# sampling parameters
|
|
||||||
timesteps: list[float],
|
|
||||||
step_callback: Callable[[], None],
|
|
||||||
guidance: float = 4.0,
|
|
||||||
):
|
|
||||||
# guidance_vec is ignored for schnell.
|
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
|
||||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
|
||||||
pred = model(
|
|
||||||
img=img,
|
|
||||||
img_ids=img_ids,
|
|
||||||
txt=txt,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
y=vec,
|
|
||||||
timesteps=t_vec,
|
|
||||||
guidance=guidance_vec,
|
|
||||||
)
|
|
||||||
|
|
||||||
img = img + (t_prev - t_curr) * pred
|
|
||||||
step_callback()
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
|
||||||
return rearrange(
|
|
||||||
x,
|
|
||||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
|
||||||
h=math.ceil(height / 16),
|
|
||||||
w=math.ceil(width / 16),
|
|
||||||
ph=2,
|
|
||||||
pw=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Convert an input image in latent space to patches for diffusion.
|
|
||||||
|
|
||||||
This implementation was extracted from:
|
|
||||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
|
||||||
"""
|
|
||||||
bs, c, h, w = latent_img.shape
|
|
||||||
|
|
||||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
|
||||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
if img.shape[0] == 1 and bs > 1:
|
|
||||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
|
||||||
|
|
||||||
# Generate patch position ids.
|
|
||||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
return img, img_ids
|
|
||||||
135 invokeai/backend/flux/sampling_utils.py Normal file
@@ -0,0 +1,135 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import math
from typing import Callable

import torch
from einops import rearrange, repeat


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
    rand_device = "cpu"
    rand_dtype = torch.float16
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
    """Find the last index in timesteps that is >= val.

    We use epsilon-close equality to avoid potential floating point errors.
    """
    idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
    assert idx >= 0
    return idx


def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
    """Clip the timestep schedule to the denoising range.

    Args:
        timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
        denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
            would mean that the denoising process start at the last timestep in the schedule >= 0.8.
        denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
            mean that the denoising process end at the last timestep in the schedule >= 0.2.

    Returns:
        list[float]: The clipped timestep schedule.
    """
    assert 0.0 <= denoising_start <= 1.0
    assert 0.0 <= denoising_end <= 1.0
    assert denoising_start <= denoising_end

    t_start_val = 1.0 - denoising_start
    t_end_val = 1.0 - denoising_end

    t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
    t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)

    clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]

    return clipped_timesteps


def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack flat array of patch embeddings to latent image."""
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )


def pack(x: torch.Tensor) -> torch.Tensor:
    """Pack latent image to flattened array of patch embeddings."""
    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Generate tensor of image position ids.

    Args:
        h (int): Height of image in latent space.
        w (int): Width of image in latent space.
        batch_size (int): Batch size.
        device (torch.device): Device.
        dtype (torch.dtype): dtype.

    Returns:
        torch.Tensor: Image position ids.
    """
    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    return img_ids
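Taken together, the helpers in sampling_utils.py define the packed-latent layout and the timestep schedule that denoise() consumes. An illustrative sketch, not part of the diff, assuming a 512x512 image (i.e. a 64x64, 16-channel latent):

import torch

from invokeai.backend.flux.sampling_utils import clip_timestep_schedule, get_schedule, pack, unpack

timesteps = get_schedule(num_steps=4, image_seq_len=(64 // 2) * (64 // 2))
assert len(timesteps) == 5  # num_steps + 1 values, running from 1.0 down to 0.0

# Image-to-image would start partway through the schedule:
partial = clip_timestep_schedule(timesteps, denoising_start=0.5, denoising_end=1.0)

latents = torch.randn(1, 16, 64, 64)
packed = pack(latents)                  # (1, 32 * 32, 16 * 2 * 2)
assert packed.shape == (1, 1024, 64)
assert torch.equal(unpack(packed, height=512, width=512), latents)  # lossless round trip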
@@ -1,672 +0,0 @@
|
|||||||
# Copyright (c) 2024 The InvokeAI Development team
|
|
||||||
"""LoRA model support."""
|
|
||||||
|
|
||||||
import bisect
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
|
||||||
from invokeai.backend.raw_model import RawModel
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
|
||||||
# rank: Optional[int]
|
|
||||||
# alpha: Optional[float]
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
# layer_key: str
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def scale(self):
|
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
if "alpha" in values:
|
|
||||||
self.alpha = values["alpha"].item()
|
|
||||||
else:
|
|
||||||
self.alpha = None
|
|
||||||
|
|
||||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
|
||||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
|
||||||
values["bias_indices"],
|
|
||||||
values["bias_values"],
|
|
||||||
tuple(values["bias_size"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
|
||||||
self.layer_key = layer_key
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
|
||||||
return self.bias
|
|
||||||
|
|
||||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
|
||||||
params = {"weight": self.get_weight(orig_module.weight)}
|
|
||||||
bias = self.get_bias(orig_module.bias)
|
|
||||||
if bias is not None:
|
|
||||||
params["bias"] = bias
|
|
||||||
return params
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for val in [self.bias]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
|
||||||
"""Log a warning if values contains unhandled keys."""
|
|
||||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
|
||||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
|
||||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
|
||||||
unknown_keys = set(values.keys()) - all_known_keys
|
|
||||||
if unknown_keys:
|
|
||||||
logger.warning(
|
|
||||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
|
||||||
class LoRALayer(LoRALayerBase):
|
|
||||||
# up: torch.Tensor
|
|
||||||
# mid: Optional[torch.Tensor]
|
|
||||||
# down: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.up = values["lora_up.weight"]
|
|
||||||
self.down = values["lora_down.weight"]
|
|
||||||
self.mid = values.get("lora_mid.weight", None)
|
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
|
||||||
self.check_keys(
|
|
||||||
values,
|
|
||||||
{
|
|
||||||
"lora_up.weight",
|
|
||||||
"lora_down.weight",
|
|
||||||
"lora_mid.weight",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
if self.mid is not None:
|
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
|
||||||
else:
|
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.up, self.mid, self.down]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.up = self.up.to(device=device, dtype=dtype)
|
|
||||||
self.down = self.down.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.mid is not None:
|
|
||||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
|
||||||
# w1_a: torch.Tensor
|
|
||||||
# w1_b: torch.Tensor
|
|
||||||
# w2_a: torch.Tensor
|
|
||||||
# w2_b: torch.Tensor
|
|
||||||
# t1: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.w1_a = values["hada_w1_a"]
|
|
||||||
self.w1_b = values["hada_w1_b"]
|
|
||||||
self.w2_a = values["hada_w2_a"]
|
|
||||||
self.w2_b = values["hada_w2_b"]
|
|
||||||
self.t1 = values.get("hada_t1", None)
|
|
||||||
self.t2 = values.get("hada_t2", None)
|
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
|
||||||
self.check_keys(
|
|
||||||
values,
|
|
||||||
{
|
|
||||||
"hada_w1_a",
|
|
||||||
"hada_w1_b",
|
|
||||||
"hada_w2_a",
|
|
||||||
"hada_w2_b",
|
|
||||||
"hada_t1",
|
|
||||||
"hada_t2",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
if self.t1 is None:
|
|
||||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
|
||||||
|
|
||||||
else:
|
|
||||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
|
||||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
|
||||||
weight = rebuild1 * rebuild2
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t1 is not None:
|
|
||||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
|
||||||
# w1: Optional[torch.Tensor] = None
|
|
||||||
# w1_a: Optional[torch.Tensor] = None
|
|
||||||
# w1_b: Optional[torch.Tensor] = None
|
|
||||||
# w2: Optional[torch.Tensor] = None
|
|
||||||
# w2_a: Optional[torch.Tensor] = None
|
|
||||||
# w2_b: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.w1 = values.get("lokr_w1", None)
|
|
||||||
if self.w1 is None:
|
|
||||||
self.w1_a = values["lokr_w1_a"]
|
|
||||||
self.w1_b = values["lokr_w1_b"]
|
|
||||||
else:
|
|
||||||
self.w1_b = None
|
|
||||||
self.w1_a = None
|
|
||||||
|
|
||||||
self.w2 = values.get("lokr_w2", None)
|
|
||||||
if self.w2 is None:
|
|
||||||
self.w2_a = values["lokr_w2_a"]
|
|
||||||
self.w2_b = values["lokr_w2_b"]
|
|
||||||
else:
|
|
||||||
self.w2_a = None
|
|
||||||
self.w2_b = None
|
|
||||||
|
|
||||||
self.t2 = values.get("lokr_t2", None)
|
|
||||||
|
|
||||||
if self.w1_b is not None:
|
|
||||||
self.rank = self.w1_b.shape[0]
|
|
||||||
elif self.w2_b is not None:
|
|
||||||
self.rank = self.w2_b.shape[0]
|
|
||||||
else:
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
self.check_keys(
|
|
||||||
values,
|
|
||||||
{
|
|
||||||
"lokr_w1",
|
|
||||||
"lokr_w1_a",
|
|
||||||
"lokr_w1_b",
|
|
||||||
"lokr_w2",
|
|
||||||
"lokr_w2_a",
|
|
||||||
"lokr_w2_b",
|
|
||||||
"lokr_t2",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
w1: Optional[torch.Tensor] = self.w1
|
|
||||||
if w1 is None:
|
|
||||||
assert self.w1_a is not None
|
|
||||||
assert self.w1_b is not None
|
|
||||||
w1 = self.w1_a @ self.w1_b
|
|
||||||
|
|
||||||
w2 = self.w2
|
|
||||||
if w2 is None:
|
|
||||||
if self.t2 is None:
|
|
||||||
assert self.w2_a is not None
|
|
||||||
assert self.w2_b is not None
|
|
||||||
w2 = self.w2_a @ self.w2_b
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
w2 = w2.contiguous()
|
|
||||||
assert w1 is not None
|
|
||||||
assert w2 is not None
|
|
||||||
weight = torch.kron(w1, w2)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w1 is not None:
|
|
||||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
assert self.w1_a is not None
|
|
||||||
assert self.w1_b is not None
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w2 is not None:
|
|
||||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
assert self.w2_a is not None
|
|
||||||
assert self.w2_b is not None
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class FullLayer(LoRALayerBase):
|
|
||||||
# bias handled in LoRALayerBase(calc_size, to)
|
|
||||||
# weight: torch.Tensor
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.weight = values["diff"]
|
|
||||||
self.bias = values.get("diff_b", None)
|
|
||||||
|
|
||||||
self.rank = None # unscaled
|
|
||||||
self.check_keys(values, {"diff", "diff_b"})
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
model_size += self.weight.nelement() * self.weight.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class IA3Layer(LoRALayerBase):
|
|
||||||
# weight: torch.Tensor
|
|
||||||
# on_input: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.weight = values["weight"]
|
|
||||||
self.on_input = values["on_input"]
|
|
||||||
|
|
||||||
self.rank = None # unscaled
|
|
||||||
self.check_keys(values, {"weight", "on_input"})
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
weight = self.weight
|
|
||||||
if not self.on_input:
|
|
||||||
weight = weight.reshape(-1, 1)
|
|
||||||
assert orig_weight is not None
|
|
||||||
return orig_weight * weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
model_size += self.weight.nelement() * self.weight.element_size()
|
|
||||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
|
||||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class NormLayer(LoRALayerBase):
|
|
||||||
# bias handled in LoRALayerBase(calc_size, to)
|
|
||||||
# weight: torch.Tensor
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.weight = values["w_norm"]
|
|
||||||
self.bias = values.get("b_norm", None)
|
|
||||||
|
|
||||||
self.rank = None # unscaled
|
|
||||||
self.check_keys(values, {"w_norm", "b_norm"})
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
model_size += self.weight.nelement() * self.weight.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|
||||||
_name: str
|
|
||||||
layers: Dict[str, AnyLoRALayer]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
layers: Dict[str, AnyLoRALayer],
|
|
||||||
):
|
|
||||||
self._name = name
|
|
||||||
self.layers = layers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
# TODO: try revert if exception?
|
|
||||||
for _key, layer in self.layers.items():
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for _, layer in self.layers.items():
|
|
||||||
model_size += layer.calc_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
||||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
|
||||||
|
|
||||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
|
||||||
diffusers format, then this function will have no effect.
|
|
||||||
|
|
||||||
This function is adapted from:
|
|
||||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
        Args:
            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

        Raises:
            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

        Returns:
            Dict[str, Tensor]: The diffusers-format state_dict.
        """
        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
        not_converted_count = 0  # The number of keys that were not converted.

        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
        # `input_blocks_4_1_proj_in`.
        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
        stability_unet_keys.sort()

        new_state_dict = {}
        for full_key, value in state_dict.items():
            if full_key.startswith("lora_unet_"):
                search_key = full_key.replace("lora_unet_", "")
                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
                position = bisect.bisect_right(stability_unet_keys, search_key)
                map_key = stability_unet_keys[position - 1]
                # Now, check if the map_key *actually* matches the search_key.
                if search_key.startswith(map_key):
                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                    new_state_dict[new_key] = value
                    converted_count += 1
                else:
                    new_state_dict[full_key] = value
                    not_converted_count += 1
            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
                new_state_dict[full_key] = value
                continue
            else:
                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

        if converted_count > 0 and not_converted_count > 0:
            raise ValueError(
                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
                f" not_converted={not_converted_count}"
            )

        return new_state_dict

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        base_model: Optional[BaseModelType] = None,
    ) -> Self:
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32

        if isinstance(file_path, str):
            file_path = Path(file_path)

        model = cls(
            name=file_path.stem,
            layers={},
        )

        if file_path.suffix == ".safetensors":
            sd = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            sd = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(sd)

        if base_model == BaseModelType.StableDiffusionXL:
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

        for layer_key, values in state_dict.items():
            # Detect layers according to LyCORIS detection logic (`weight_list_det`):
            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

            # lora and locon
            if "lora_up.weight" in values:
                layer: AnyLoRALayer = LoRALayer(layer_key, values)
            # loha
            elif "hada_w1_a" in values:
                layer = LoHALayer(layer_key, values)
            # lokr
            elif "lokr_w1" in values or "lokr_w1_a" in values:
                layer = LoKRLayer(layer_key, values)
            # diff
            elif "diff" in values:
                layer = FullLayer(layer_key, values)
            # ia3
            elif "on_input" in values:
                layer = IA3Layer(layer_key, values)
            # norms
            elif "w_norm" in values:
                layer = NormLayer(layer_key, values)
            else:
                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                raise Exception("Unknown lora format!")

            # Lower memory consumption by removing already-parsed layer values.
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype)
            model.layers[layer_key] = layer

        return model

    @staticmethod
    def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            if stem not in state_dict_groupped:
                state_dict_groupped[stem] = {}
            state_dict_groupped[stem][leaf] = value

        return state_dict_groupped


# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}
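For orientation, two entries that the comprehension above produces (spot-checks derived from the prefix rules in make_sdxl_unet_conversion_map; not present in the source):

assert SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP["input_blocks_4_1"] == "down_blocks_1_attentions_0"
assert SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP["middle_block_1"] == "mid_block_attentions_0"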
0  invokeai/backend/lora/__init__.py  Normal file
0  invokeai/backend/lora/conversions/__init__.py  Normal file
@@ -0,0 +1,210 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw


def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
    """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
    (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
    all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())

    # Next, check that this is likely a FLUX model by spot-checking a few keys.
    expected_keys = [
        "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
        "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
    ]
    all_expected_keys_present = all(k in state_dict for k in expected_keys)

    return all_keys_in_peft_format and all_expected_keys_present


# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw:
    """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.

    This function is based on:
    https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
    """
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)

    # Remove the "transformer." prefix from all keys.
    grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}

    # Constants for FLUX.1
    num_double_layers = 19
    num_single_layers = 38
    # inner_dim = 3072
    # mlp_ratio = 4.0

    layers: dict[str, AnyLoRALayer] = {}

    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
        if src_key in grouped_state_dict:
            src_layer_dict = grouped_state_dict.pop(src_key)
            layers[dst_key] = LoRALayer(
                values={
                    "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                    "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                    "alpha": torch.tensor(alpha),
                },
            )
            assert len(src_layer_dict) == 0

    def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
        """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
        stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
        """
        # We expect that either all src keys are present or none of them are. Verify this.
        keys_present = [key in grouped_state_dict for key in src_keys]
        assert all(keys_present) or not any(keys_present)

        # If none of the keys are present, return early.
        if not any(keys_present):
            return

        src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
        sub_layers: list[LoRALayerBase] = []
        for src_layer_dict in src_layer_dicts:
            sub_layers.append(
                LoRALayer(
                    values={
                        "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                        "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                        "alpha": torch.tensor(alpha),
                    },
                )
            )
            assert len(src_layer_dict) == 0
        layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)

    # time_text_embed.timestep_embedder -> time_in.
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")

    # time_text_embed.text_embedder -> vector_in.
    add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
    add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")

    # time_text_embed.guidance_embedder -> guidance_in.
    add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
    add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")

    # context_embedder -> txt_in.
    add_lora_layer_if_present("context_embedder", "txt_in")

    # x_embedder -> img_in.
    add_lora_layer_if_present("x_embedder", "img_in")

    # Double transformer blocks.
    for i in range(num_double_layers):
        # norms.
        add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
        add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")

        # Q, K, V
        add_qkv_lora_layer_if_present(
            [
                f"transformer_blocks.{i}.attn.to_q",
                f"transformer_blocks.{i}.attn.to_k",
                f"transformer_blocks.{i}.attn.to_v",
            ],
            f"double_blocks.{i}.img_attn.qkv",
        )
        add_qkv_lora_layer_if_present(
            [
                f"transformer_blocks.{i}.attn.add_q_proj",
                f"transformer_blocks.{i}.attn.add_k_proj",
                f"transformer_blocks.{i}.attn.add_v_proj",
            ],
            f"double_blocks.{i}.txt_attn.qkv",
        )

        # ff img_mlp
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff.net.0.proj",
            f"double_blocks.{i}.img_mlp.0",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff.net.2",
            f"double_blocks.{i}.img_mlp.2",
        )

        # ff txt_mlp
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff_context.net.0.proj",
            f"double_blocks.{i}.txt_mlp.0",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff_context.net.2",
            f"double_blocks.{i}.txt_mlp.2",
        )

        # output projections.
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.attn.to_out.0",
            f"double_blocks.{i}.img_attn.proj",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.attn.to_add_out",
            f"double_blocks.{i}.txt_attn.proj",
        )

    # Single transformer blocks.
    for i in range(num_single_layers):
        # norms
        add_lora_layer_if_present(
            f"single_transformer_blocks.{i}.norm.linear",
            f"single_blocks.{i}.modulation.lin",
        )

        # Q, K, V, mlp
        add_qkv_lora_layer_if_present(
            [
                f"single_transformer_blocks.{i}.attn.to_q",
                f"single_transformer_blocks.{i}.attn.to_k",
                f"single_transformer_blocks.{i}.attn.to_v",
                f"single_transformer_blocks.{i}.proj_mlp",
            ],
            f"single_blocks.{i}.linear1",
        )

        # Output projections.
        add_lora_layer_if_present(
            f"single_transformer_blocks.{i}.proj_out",
            f"single_blocks.{i}.linear2",
        )

    # Final layer.
    add_lora_layer_if_present("proj_out", "final_layer.linear")

    # Assert that all keys were processed.
    assert len(grouped_state_dict) == 0

    return LoRAModelRaw(layers=layers)


def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Groups the keys in the state dict by layer."""
    layer_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key in state_dict:
        # Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
        parts = key.rsplit(".", maxsplit=2)
        layer_name = parts[0]
        key_name = ".".join(parts[1:])
        if layer_name not in layer_dict:
            layer_dict[layer_name] = {}
        layer_dict[layer_name][key_name] = state_dict[key]
    return layer_dict
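A minimal usage sketch for the two helpers above (the file name is hypothetical; assumes the safetensors package):

from safetensors.torch import load_file

sd = load_file("flux_lora_diffusers.safetensors")  # hypothetical path
if is_state_dict_likely_in_flux_diffusers_format(sd):
    lora_model = lora_model_from_flux_diffusers_state_dict(sd, alpha=1.0)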
@@ -0,0 +1,80 @@
import re
from typing import Any, Dict, TypeVar

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
#   lora_unet_double_blocks_0_img_attn_proj.alpha
#   lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
#   lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)


def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())


def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        layer_name, param_name = key.split(".", 1)
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Convert the state dict to the InvokeAI format.
    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)

    # Create LoRA layers.
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)


T = TypeVar("T")


def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Converts a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Example key conversion:
        "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
    """

    def replace_func(match: re.Match[str]) -> str:
        s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        if match.group(4):
            s += f".{match.group(4)}"
        return s

    converted_dict: dict[str, T] = {}
    for k, v in state_dict.items():
        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
        if match:
            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
            converted_dict[new_key] = v
        else:
            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

    return converted_dict
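As a quick, illustrative spot-check of the regex-driven renaming (not part of the source):

converted = convert_flux_kohya_state_dict_to_invoke_format(
    {"lora_unet_double_blocks_0_img_attn_proj": {}}  # value dict left empty purely for illustration
)
assert list(converted.keys()) == ["double_blocks.0.img_attn.proj"]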
@@ -0,0 +1,29 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw


def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)

    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, values in grouped_state_dict.items():
        layers[layer_key] = any_lora_layer_from_state_dict(values)

    return LoRAModelRaw(layers=layers)


def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
    state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        stem, leaf = key.split(".", 1)
        if stem not in state_dict_groupped:
            state_dict_groupped[stem] = {}
        state_dict_groupped[stem][leaf] = value

    return state_dict_groupped
154  invokeai/backend/lora/conversions/sdxl_lora_conversion_utils.py  Normal file
@@ -0,0 +1,154 @@
import bisect
from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")


def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
    """Convert the keys of an SDXL LoRA state_dict to diffusers format.

    The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
    diffusers format, then this function will have no effect.

    This function is adapted from:
    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

    Args:
        state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

    Raises:
        ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

    Returns:
        Dict[str, Tensor]: The diffusers-format state_dict.
    """
    converted_count = 0  # The number of Stability AI keys converted to diffusers format.
    not_converted_count = 0  # The number of keys that were not converted.

    # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
    # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
    # `input_blocks_4_1_proj_in`.
    stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
    stability_unet_keys.sort()

    new_state_dict: dict[str, T] = {}
    for full_key, value in state_dict.items():
        if full_key.startswith("lora_unet_"):
            search_key = full_key.replace("lora_unet_", "")
            # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
            position = bisect.bisect_right(stability_unet_keys, search_key)
            map_key = stability_unet_keys[position - 1]
            # Now, check if the map_key *actually* matches the search_key.
            if search_key.startswith(map_key):
                new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                new_state_dict[new_key] = value
                converted_count += 1
            else:
                new_state_dict[full_key] = value
                not_converted_count += 1
        elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
            # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
            new_state_dict[full_key] = value
            continue
        else:
            raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

    if converted_count > 0 and not_converted_count > 0:
        raise ValueError(
            f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
            f" not_converted={not_converted_count}"
        )

    return new_state_dict


# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
    unet_conversion_map_layer: list[tuple[str, str]] = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map: list[tuple[str, str]] = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
}
0  invokeai/backend/lora/layers/__init__.py  Normal file
11  invokeai/backend/lora/layers/any_lora_layer.py  Normal file
@@ -0,0 +1,11 @@
from typing import Union

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer

AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
46  invokeai/backend/lora/layers/concatenated_lora_layer.py  Normal file
@@ -0,0 +1,46 @@
from typing import List, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class ConcatenatedLoRALayer(LoRALayerBase):
    """A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.

    This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
    Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices
    are stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
    """

    def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
        # Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
        super().__init__(values={})

        self._lora_layers = lora_layers
        self._concat_axis = concat_axis

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
        # require this value, we will need to implement chunking of the original weight tensor here.
        layer_weights = [lora_layer.get_weight(None) for lora_layer in self._lora_layers]  # pyright: ignore[reportArgumentType]
        return torch.cat(layer_weights, dim=self._concat_axis)

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
        # require this value, we will need to implement chunking of the original bias tensor here.
        layer_biases = [lora_layer.get_bias(None) for lora_layer in self._lora_layers]  # pyright: ignore[reportArgumentType]
        layer_bias_is_none = [layer_bias is None for layer_bias in layer_biases]
        if any(layer_bias_is_none):
            assert all(layer_bias_is_none)
            return None

        # Ignore the type error, because we have just verified that all layer biases are non-None.
        return torch.cat(layer_biases, dim=self._concat_axis)

    def calc_size(self) -> int:
        return sum(lora_layer.calc_size() for lora_layer in self._lora_layers)

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        for lora_layer in self._lora_layers:
            lora_layer.to(device=device, dtype=dtype)
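Below is a small, self-contained sketch (not from the source) of how a concatenated Q/K/V delta could be assembled from three independent LoRA layers; the shapes and rank are arbitrary illustration values.

import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer

# Three rank-4 LoRA deltas for hypothetical 64-dim Q, K and V projections.
q, k, v = (
    LoRALayer({"lora_down.weight": torch.randn(4, 64), "lora_up.weight": torch.randn(64, 4)})
    for _ in range(3)
)
qkv = ConcatenatedLoRALayer(lora_layers=[q, k, v], concat_axis=0)
assert qkv.get_weight(None).shape == (192, 64)  # per-matrix deltas stacked along the output axis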
36  invokeai/backend/lora/layers/full_layer.py  Normal file
@@ -0,0 +1,36 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class FullLayer(LoRALayerBase):
    # bias handled in LoRALayerBase(calc_size, to)
    # weight: torch.Tensor
    # bias: Optional[torch.Tensor]

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(values)

        self.weight = values["diff"]
        self.bias = values.get("diff_b", None)

        self.rank = None  # unscaled
        self.check_keys(values, {"diff", "diff_b"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
41  invokeai/backend/lora/layers/ia3_layer.py  Normal file
@@ -0,0 +1,41 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class IA3Layer(LoRALayerBase):
    # weight: torch.Tensor
    # on_input: torch.Tensor

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(values)

        self.weight = values["weight"]
        self.on_input = values["on_input"]

        self.rank = None  # unscaled
        self.check_keys(values, {"weight", "on_input"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if not self.on_input:
            weight = weight.reshape(-1, 1)
        assert orig_weight is not None
        return orig_weight * weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        model_size += self.on_input.nelement() * self.on_input.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
        self.on_input = self.on_input.to(device=device, dtype=dtype)
68  invokeai/backend/lora/layers/loha_layer.py  Normal file
@@ -0,0 +1,68 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class LoHALayer(LoRALayerBase):
    # w1_a: torch.Tensor
    # w1_b: torch.Tensor
    # w2_a: torch.Tensor
    # w2_b: torch.Tensor
    # t1: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(self, values: Dict[str, torch.Tensor]):
        super().__init__(values)

        self.w1_a = values["hada_w1_a"]
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]
        self.t1 = values.get("hada_t1", None)
        self.t2 = values.get("hada_t2", None)

        self.rank = self.w1_b.shape[0]
        self.check_keys(
            values,
            {
                "hada_w1_a",
                "hada_w1_b",
                "hada_w2_a",
                "hada_w2_b",
                "hada_t1",
                "hada_t2",
            },
        )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.t1 is None:
            weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
        else:
            rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        if self.t1 is not None:
            self.t1 = self.t1.to(device=device, dtype=dtype)

        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
113  invokeai/backend/lora/layers/lokr_layer.py  Normal file
@@ -0,0 +1,113 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class LoKRLayer(LoRALayerBase):
    # w1: Optional[torch.Tensor] = None
    # w1_a: Optional[torch.Tensor] = None
    # w1_b: Optional[torch.Tensor] = None
    # w2: Optional[torch.Tensor] = None
    # w2_a: Optional[torch.Tensor] = None
    # w2_b: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(values)

        self.w1 = values.get("lokr_w1", None)
        if self.w1 is None:
            self.w1_a = values["lokr_w1_a"]
            self.w1_b = values["lokr_w1_b"]
        else:
            self.w1_b = None
            self.w1_a = None

        self.w2 = values.get("lokr_w2", None)
        if self.w2 is None:
            self.w2_a = values["lokr_w2_a"]
            self.w2_b = values["lokr_w2_b"]
        else:
            self.w2_a = None
            self.w2_b = None

        self.t2 = values.get("lokr_t2", None)

        if self.w1_b is not None:
            self.rank = self.w1_b.shape[0]
        elif self.w2_b is not None:
            self.rank = self.w2_b.shape[0]
        else:
            self.rank = None  # unscaled

        self.check_keys(
            values,
            {
                "lokr_w1",
                "lokr_w1_a",
                "lokr_w1_b",
                "lokr_w2",
                "lokr_w2_a",
                "lokr_w2_b",
                "lokr_t2",
            },
        )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        w1: Optional[torch.Tensor] = self.w1
        if w1 is None:
            assert self.w1_a is not None
            assert self.w1_b is not None
            w1 = self.w1_a @ self.w1_b

        w2 = self.w2
        if w2 is None:
            if self.t2 is None:
                assert self.w2_a is not None
                assert self.w2_b is not None
                w2 = self.w2_a @ self.w2_b
            else:
                w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)

        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
        assert w1 is not None
        assert w2 is not None
        weight = torch.kron(w1, w2)

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        if self.w1 is not None:
            self.w1 = self.w1.to(device=device, dtype=dtype)
        else:
            assert self.w1_a is not None
            assert self.w1_b is not None
            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

        if self.w2 is not None:
            self.w2 = self.w2.to(device=device, dtype=dtype)
        else:
            assert self.w2_a is not None
            assert self.w2_b is not None
            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)
58  invokeai/backend/lora/layers/lora_layer.py  Normal file
@@ -0,0 +1,58 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
    # up: torch.Tensor
    # mid: Optional[torch.Tensor]
    # down: torch.Tensor

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(values)

        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
        self.mid = values.get("lora_mid.weight", None)

        self.rank = self.down.shape[0]
        self.check_keys(
            values,
            {
                "lora_up.weight",
                "lora_down.weight",
                "lora_mid.weight",
            },
        )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.mid is not None:
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
        else:
            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.up, self.mid, self.down]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)
71  invokeai/backend/lora/layers/lora_layer_base.py  Normal file
@@ -0,0 +1,71 @@
from typing import Dict, Optional, Set

import torch

import invokeai.backend.util.logging as logger


class LoRALayerBase:
    # rank: Optional[int]
    # alpha: Optional[float]
    # bias: Optional[torch.Tensor]

    # @property
    # def scale(self):
    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        if "alpha" in values:
            self.alpha = values["alpha"].item()
        else:
            self.alpha = None

        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
            self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )
        else:
            self.bias = None

        self.rank = None  # set in layer implementation

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        params = {"weight": self.get_weight(orig_module.weight)}
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias
        return params

    def calc_size(self) -> int:
        model_size = 0
        for val in [self.bias]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)

    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
        """Log a warning if values contains unhandled keys."""
        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handle.
        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
        unknown_keys = set(values.keys()) - all_known_keys
        if unknown_keys:
            logger.warning(
                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
            )
36  invokeai/backend/lora/layers/norm_layer.py  Normal file
@@ -0,0 +1,36 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class NormLayer(LoRALayerBase):
    # bias handled in LoRALayerBase(calc_size, to)
    # weight: torch.Tensor
    # bias: Optional[torch.Tensor]

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(values)

        self.weight = values["w_norm"]
        self.bias = values.get("b_norm", None)

        self.rank = None  # unscaled
        self.check_keys(values, {"w_norm", "b_norm"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
33  invokeai/backend/lora/layers/utils.py  Normal file
@@ -0,0 +1,33 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
    # Detect layers according to LyCORIS detection logic (`weight_list_det`):
    # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

    if "lora_up.weight" in state_dict:
        # LoRA a.k.a LoCon
        return LoRALayer(state_dict)
    elif "hada_w1_a" in state_dict:
        return LoHALayer(state_dict)
    elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
        return LoKRLayer(state_dict)
    elif "diff" in state_dict:
        # Full a.k.a Diff
        return FullLayer(state_dict)
    elif "on_input" in state_dict:
        return IA3Layer(state_dict)
    elif "w_norm" in state_dict:
        return NormLayer(state_dict)
    else:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
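For example (a minimal, illustrative call with dummy tensors, not from the source):

import torch

layer = any_lora_layer_from_state_dict(
    {"lora_down.weight": torch.zeros(4, 16), "lora_up.weight": torch.zeros(16, 4)}
)
assert isinstance(layer, LoRALayer)  # dispatched by the "lora_up.weight" key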
22  invokeai/backend/lora/lora_model_raw.py  Normal file
@@ -0,0 +1,22 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel


class LoRAModelRaw(RawModel):  # (torch.nn.Module):
    def __init__(self, layers: Dict[str, AnyLoRALayer]):
        self.layers = layers

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        for _key, layer in self.layers.items():
            layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        model_size = 0
        for _, layer in self.layers.items():
            model_size += layer.calc_size()
        return model_size
148  invokeai/backend/lora/lora_patcher.py  Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||||
|
|
||||||
|
|
||||||
|
class LoraPatcher:
|
||||||
|
@staticmethod
|
||||||
|
@torch.no_grad()
|
||||||
|
@contextmanager
|
||||||
|
def apply_lora_patches(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||||
|
prefix: str,
|
||||||
|
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
):
|
||||||
|
"""Apply one or more LoRA patches to a model within a context manager.
|
||||||
|
|
||||||
|
:param model: The model to patch.
|
||||||
|
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
|
||||||
|
that the LoRA patches do not need to be loaded into memory all at once.
|
||||||
|
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
|
||||||
|
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
|
||||||
|
"""
|
||||||
|
original_weights = OriginalWeightsStorage(cached_weights)
|
||||||
|
try:
|
||||||
|
for patch, patch_weight in patches:
|
||||||
|
LoraPatcher.apply_lora_patch(
|
||||||
|
model=model,
|
||||||
|
prefix=prefix,
|
||||||
|
patch=patch,
|
||||||
|
patch_weight=patch_weight,
|
||||||
|
original_weights=original_weights,
|
||||||
|
)
|
||||||
|
del patch
|
||||||
|
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for param_key, weight in original_weights.get_changed_weights():
|
||||||
|
model.get_parameter(param_key).copy_(weight)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.no_grad()
|
||||||
|
def apply_lora_patch(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
prefix: str,
|
||||||
|
patch: LoRAModelRaw,
|
||||||
|
patch_weight: float,
|
||||||
|
original_weights: OriginalWeightsStorage,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Apply a single LoRA patch to a model.
|
||||||
|
:param model: The model to patch.
|
||||||
|
:param patch: LoRA model to patch in.
|
||||||
|
:param patch_weight: LoRA patch weight.
|
||||||
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
|
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LoraPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype

            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

            # We intentionally move to the target device first, then cast. Experimentally, this was found to
            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
            # same thing in a single call to '.to(...)'.
            layer.to(device=device)
            layer.to(dtype=torch.float32)

            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
            for param_name, lora_param_weight in layer.get_parameters(module).items():
                param_key = module_key + "." + param_name
                module_param = module.get_parameter(param_name)

                # Save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != lora_param_weight.shape:
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)

                lora_param_weight *= patch_weight * layer_scale
                module_param += lora_param_weight.to(dtype=dtype)

            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
    ) -> tuple[str, torch.nn.Module]:
        """Get the submodule corresponding to the given layer key.
        :param model: The model to search.
        :param layer_key: The layer key to search for.
        :param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
            with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
            searching, but some legacy code still uses flattened keys.
        :return: A tuple containing the module key and the submodule.
        """
        if not layer_key_is_flattened:
            return layer_key, model.get_submodule(layer_key)

        # Handle flattened keys.
        assert "." not in layer_key

        module = model
        module_key = ""
        key_parts = layer_key.split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return module_key, module
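For orientation, here is a minimal usage sketch of the patcher defined above. The LoRA object and host module are placeholders, and the `apply_lora_patches` entry point is assumed to be the context manager whose body opens this section; nothing in this sketch is part of the diff itself.

# Sketch only: `lora` is assumed to be a LoRAModelRaw loaded elsewhere.
import torch

from invokeai.backend.lora.lora_patcher import LoraPatcher  # import path as added elsewhere in this diff

unet: torch.nn.Module = ...  # the model to patch (e.g. a UNet or FLUX transformer)
patches = iter([(lora, 0.75)])  # (LoRAModelRaw, weight) pairs; an iterator keeps peak memory low

with LoraPatcher.apply_lora_patches(model=unet, patches=patches, prefix="lora_unet_", cached_weights=None):
    ...  # run inference here with the patched weights
# On exit, every changed parameter is copied back from OriginalWeightsStorage.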
@@ -66,8 +66,9 @@ class ModelLoader(ModelLoaderBase):
         return (model_base / config.path).resolve()
 
     def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
+        stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
         try:
-            return self._ram_cache.get(config.key, submodel_type)
+            return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
         except IndexError:
             pass
 
@@ -84,7 +85,7 @@ class ModelLoader(ModelLoaderBase):
         return self._ram_cache.get(
             key=config.key,
             submodel_type=submodel_type,
-            stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
+            stats_name=stats_name,
         )
 
     def get_size_fs(
@@ -128,7 +128,24 @@ class ModelCacheBase(ABC, Generic[T]):
     @property
     @abstractmethod
     def max_cache_size(self) -> float:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
+        """Return the maximum size the RAM cache can grow to."""
+        pass
+
+    @max_cache_size.setter
+    @abstractmethod
+    def max_cache_size(self, value: float) -> None:
+        """Set the cap on vram cache size."""
+
+    @property
+    @abstractmethod
+    def max_vram_cache_size(self) -> float:
+        """Return the maximum size the VRAM cache can grow to."""
+        pass
+
+    @max_vram_cache_size.setter
+    @abstractmethod
+    def max_vram_cache_size(self, value: float) -> float:
+        """Set the maximum size the VRAM cache can grow to."""
         pass
 
     @abstractmethod
@@ -70,6 +70,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         max_vram_cache_size: float,
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
+        precision: torch.dtype = torch.float16,
         lazy_offloading: bool = True,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
@@ -81,11 +82,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
+        :param precision: Precision for loaded models [torch.float16]
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
+        :param logger: InvokeAILogger to use (otherwise creates one)
         """
         # allow lazy offloading only when vram cache enabled
         self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
@@ -130,6 +133,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
         """Set the cap on cache size."""
         self._max_cache_size = value
 
+    @property
+    def max_vram_cache_size(self) -> float:
+        """Return the cap on vram cache size."""
+        return self._max_vram_cache_size
+
+    @max_vram_cache_size.setter
+    def max_vram_cache_size(self, value: float) -> None:
+        """Set the cap on vram cache size."""
+        self._max_vram_cache_size = value
+
     @property
     def stats(self) -> Optional[CacheStats]:
         """Return collected CacheStats object."""
@@ -32,6 +32,9 @@ from invokeai.backend.model_manager.config import (
 )
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.model_manager.util.model_util import (
+    convert_bundle_to_flux_transformer_checkpoint,
+)
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -190,6 +193,13 @@ class FluxCheckpointModel(ModelLoader):
         with SilenceWarnings():
             model = Flux(params[config.config_path])
             sd = load_file(model_path)
+            if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+                sd = convert_bundle_to_flux_transformer_checkpoint(sd)
+            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+            self._ram_cache.make_room(new_sd_size)
+            for k in sd.keys():
+                # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+                sd[k] = sd[k].to(torch.bfloat16)
             model.load_state_dict(sd, assign=True)
             return model
 
@@ -230,5 +240,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
             model = Flux(params[config.config_path])
             model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
+            if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+                sd = convert_bundle_to_flux_transformer_checkpoint(sd)
             model.load_state_dict(sd, assign=True)
             return model
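The FluxCheckpointModel change above reserves RAM-cache room based on the bfloat16 footprint of the state dict before casting it. A small sketch of that arithmetic, with an illustrative function name:

import torch

def bf16_state_dict_nbytes(sd: dict[str, torch.Tensor]) -> int:
    # Each element occupies torch.bfloat16.itemsize (2) bytes once the cast is done.
    return sum(t.nelement() * torch.bfloat16.itemsize for t in sd.values())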
@@ -5,8 +5,18 @@ from logging import Logger
 from pathlib import Path
 from typing import Optional
 
+import torch
+from safetensors.torch import load_file
+
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
+    lora_model_from_flux_diffusers_state_dict,
+)
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
+    lora_model_from_flux_kohya_state_dict,
+)
+from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
+from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -45,14 +55,33 @@ class LoRALoader(ModelLoader):
             raise ValueError("There are no submodels in a LoRA model.")
         model_path = Path(config.path)
         assert self._model_base is not None
-        model = LoRAModelRaw.from_checkpoint(
-            file_path=model_path,
-            dtype=self._torch_dtype,
-            base_model=self._model_base,
-        )
+
+        # Load the state dict from the model file.
+        if model_path.suffix == ".safetensors":
+            state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
+        else:
+            state_dict = torch.load(model_path, map_location="cpu")
+
+        # Apply state_dict key conversions, if necessary.
+        if self._model_base == BaseModelType.StableDiffusionXL:
+            state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
+            model = lora_model_from_sd_state_dict(state_dict=state_dict)
+        elif self._model_base == BaseModelType.Flux:
+            if config.format == ModelFormat.Diffusers:
+                model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
+            elif config.format == ModelFormat.LyCORIS:
+                model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
+            else:
+                raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
+        elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
+            # Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
+            model = lora_model_from_sd_state_dict(state_dict=state_dict)
+        else:
+            raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
+
+        model.to(dtype=self._torch_dtype)
         return model
 
-    # override
     def _get_model_path(self, config: AnyModelConfig) -> Path:
         # cheating a little - we remember this variable for using in the subsequent call to _load_model()
         self._model_base = config.base
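The rewritten `_load_model` above first reads the raw state dict from disk and only then dispatches on base model and format. A minimal sketch of the loading half, assuming nothing beyond the two file types handled in the diff:

from pathlib import Path

import torch
from safetensors.torch import load_file

def load_lora_state_dict(model_path: Path) -> dict[str, torch.Tensor]:
    # .safetensors files are read without unpickling; anything else falls back to torch.load.
    if model_path.suffix == ".safetensors":
        return load_file(model_path.absolute().as_posix(), device="cpu")
    return torch.load(model_path, map_location="cpu")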
@@ -15,7 +15,7 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_diffusers_format,
+)
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
@@ -108,6 +112,8 @@ class ModelProbe(object):
         "CLIPVisionModelWithProjection": ModelType.CLIPVision,
         "T2IAdapter": ModelType.T2IAdapter,
         "CLIPModel": ModelType.CLIPEmbed,
+        "CLIPTextModel": ModelType.CLIPEmbed,
+        "T5EncoderModel": ModelType.T5Encoder,
     }
 
     @classmethod
@@ -224,14 +230,27 @@ class ModelProbe(object):
         ckpt = ckpt.get("state_dict", ckpt)
 
         for key in [str(k) for k in ckpt.keys()]:
-            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
+            if key.startswith(
+                (
+                    "cond_stage_model.",
+                    "first_stage_model.",
+                    "model.diffusion_model.",
+                    # FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
+                    "double_blocks.",
+                    # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
+                    # This prefix is typically used to distinguish between multiple models bundled in a single file.
+                    "model.diffusion_model.double_blocks.",
+                )
+            ):
                 # Keys starting with double_blocks are associated with Flux models
                 return ModelType.Main
             elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                 return ModelType.VAE
             elif key.startswith(("lora_te_", "lora_unet_")):
                 return ModelType.LoRA
-            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
+            # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
+            # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
+            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
                 return ModelType.LoRA
             elif key.startswith(("controlnet", "control_model", "input_blocks")):
                 return ModelType.ControlNet
@@ -283,9 +302,16 @@ class ModelProbe(object):
         if (folder_path / "image_encoder.txt").exists():
             return ModelType.IPAdapter
 
-        i = folder_path / "model_index.json"
-        c = folder_path / "config.json"
-        config_path = i if i.exists() else c if c.exists() else None
+        config_path = None
+        for p in [
+            folder_path / "model_index.json",  # pipeline
+            folder_path / "config.json",  # most diffusers
+            folder_path / "text_encoder_2" / "config.json",  # T5 text encoder
+            folder_path / "text_encoder" / "config.json",  # T5 CLIP
+        ]:
+            if p.exists():
+                config_path = p
+                break
 
         if config_path:
             with open(config_path, "r") as file:
@@ -328,7 +354,10 @@ class ModelProbe(object):
             # TODO: Decide between dev/schnell
             checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
             state_dict = checkpoint.get("state_dict") or checkpoint
-            if "guidance_in.out_layer.weight" in state_dict:
+            if (
+                "guidance_in.out_layer.weight" in state_dict
+                or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
+            ):
                 # For flux, this is a key in invokeai.backend.flux.util.params
                 # Due to model type and format being the descriminator for model configs this
                 # is used rather than attempting to support flux with separate model types and format
@@ -336,7 +365,7 @@ class ModelProbe(object):
                 config_file = "flux-dev"
             else:
                 # For flux, this is a key in invokeai.backend.flux.util.params
-                # Due to model type and format being the descriminator for model configs this
+                # Due to model type and format being the discriminator for model configs this
                 # is used rather than attempting to support flux with separate model types and format
                 # If changed in the future, please fix me
                 config_file = "flux-schnell"
@@ -443,7 +472,10 @@ class CheckpointProbeBase(ProbeBase):
 
     def get_format(self) -> ModelFormat:
         state_dict = self.checkpoint.get("state_dict") or self.checkpoint
-        if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
+        if (
+            "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
+            or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
+        ):
             return ModelFormat.BnbQuantizednf4b
         return ModelFormat("checkpoint")
 
@@ -470,7 +502,10 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
         state_dict = self.checkpoint.get("state_dict") or checkpoint
-        if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
+        if (
+            "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
+            or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
+        ):
             return BaseModelType.Flux
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
@@ -525,12 +560,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
     """Class for LoRA checkpoints."""
 
     def get_format(self) -> ModelFormat:
-        return ModelFormat("lycoris")
+        if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
+            # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
+            # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
+            # file, but the weight keys are in the diffusers format.
+            return ModelFormat.Diffusers
+        return ModelFormat.LyCORIS
 
     def get_base_type(self) -> BaseModelType:
-        checkpoint = self.checkpoint
-        token_vector_length = lora_token_vector_length(checkpoint)
+        if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
+            self.checkpoint
+        ):
+            return BaseModelType.Flux
+
+        # If we've gotten here, we assume that the model is a Stable Diffusion model.
+        token_vector_length = lora_token_vector_length(self.checkpoint)
         if token_vector_length == 768:
             return BaseModelType.StableDiffusion1
         elif token_vector_length == 1024:
@@ -747,8 +791,27 @@ class TextualInversionFolderProbe(FolderProbeBase):
 
 
 class T5EncoderFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
     def get_format(self) -> ModelFormat:
-        return ModelFormat.T5Encoder
+        path = self.model_path / "text_encoder_2"
+        if (path / "model.safetensors.index.json").exists():
+            return ModelFormat.T5Encoder
+        files = list(path.glob("*.safetensors"))
+        if len(files) == 0:
+            raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found")
+
+        # shortcut: look for the quantization in the name
+        if any(x for x in files if "llm_int8" in x.as_posix()):
+            return ModelFormat.BnbQuantizedLlmInt8b
+
+        # more reliable path: probe contents for a 'SCB' key
+        ckpt = read_checkpoint_meta(files[0], scan=True)
+        if any("SCB" in x for x in ckpt.keys()):
+            return ModelFormat.BnbQuantizedLlmInt8b
+
+        raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format")
 
 
 class ONNXFolderProbe(PipelineFolderProbe):
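The T5EncoderFolderProbe change above falls back to scanning a shard for bitsandbytes "SCB" scale buffers when the file name gives no hint. A hedged sketch of that check using safetensors directly; the path handling and function name are illustrative:

from safetensors import safe_open

def looks_like_llm_int8(shard_path: str) -> bool:
    # LLM.int8-quantized checkpoints store per-tensor scale buffers whose keys contain "SCB".
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        return any("SCB" in key for key in f.keys())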
@@ -133,3 +133,29 @@ def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[in
             break
 
     return lora_token_vector_length
+
+
+def convert_bundle_to_flux_transformer_checkpoint(
+    transformer_state_dict: dict[str, torch.Tensor],
+) -> dict[str, torch.Tensor]:
+    original_state_dict: dict[str, torch.Tensor] = {}
+    keys_to_remove: list[str] = []
+
+    for k, v in transformer_state_dict.items():
+        if not k.startswith("model.diffusion_model"):
+            keys_to_remove.append(k)  # This can be removed in the future if we only want to delete transformer keys
+            continue
+        if k.endswith("scale"):
+            # Scale math must be done at bfloat16 due to our current flux model
+            # support limitations at inference time
+            v = v.to(dtype=torch.bfloat16)
+        new_key = k.replace("model.diffusion_model.", "")
+        original_state_dict[new_key] = v
+        keys_to_remove.append(k)
+
+    # Remove processed keys from the original dictionary, leaving others in case
+    # other model state dicts need to be pulled
+    for k in keys_to_remove:
+        del transformer_state_dict[k]
+
+    return original_state_dict
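A short usage sketch of `convert_bundle_to_flux_transformer_checkpoint`, mirroring how the FLUX loaders earlier in this diff call it; the file name below is illustrative:

from safetensors.torch import load_file

from invokeai.backend.model_manager.util.model_util import convert_bundle_to_flux_transformer_checkpoint

sd = load_file("flux-bundle.safetensors")
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
    # Bundled checkpoints prefix transformer keys with "model.diffusion_model."; strip it off.
    sd = convert_bundle_to_flux_transformer_checkpoint(sd)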
@@ -5,32 +5,18 @@
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
-from diffusers import OnnxRuntimeModel, UNet2DConditionModel
+from diffusers import UNet2DConditionModel
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import AnyModel
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
-from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
-
-"""
-loras = [
-    (lora_model1, 0.7),
-    (lora_model2, 0.4),
-]
-with LoRAHelper.apply_lora_unet(unet, loras):
-    # unet with applied loras
-# unmodified unet
-"""
-
 
 class ModelPatcher:
@@ -54,95 +40,6 @@ class ModelPatcher:
         finally:
             unet.set_attn_processor(unet_orig_processors)
 
-    @staticmethod
-    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
-        assert "." not in lora_key
-
-        if not lora_key.startswith(prefix):
-            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
-        module = model
-        module_key = ""
-        key_parts = lora_key[len(prefix) :].split("_")
-
-        submodule_name = key_parts.pop(0)
-
-        while len(key_parts) > 0:
-            try:
-                module = module.get_submodule(submodule_name)
-                module_key += "." + submodule_name
-                submodule_name = key_parts.pop(0)
-            except Exception:
-                submodule_name += "_" + key_parts.pop(0)
-
-        module = module.get_submodule(submodule_name)
-        module_key = (module_key + "." + submodule_name).lstrip(".")
-
-        return (module_key, module)
-
-    @classmethod
-    @contextmanager
-    def apply_lora_unet(
-        cls,
-        unet: UNet2DConditionModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[None, None, None]:
-        with cls.apply_lora(
-            unet,
-            loras=loras,
-            prefix="lora_unet_",
-            cached_weights=cached_weights,
-        ):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora(
-        cls,
-        model: AnyModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        prefix: str,
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[None, None, None]:
-        """
-        Apply one or more LoRAs to a model.
-
-        :param model: The model to patch.
-        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
-        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
-        """
-        original_weights = OriginalWeightsStorage(cached_weights)
-        try:
-            for lora_model, lora_weight in loras:
-                LoRAExt.patch_model(
-                    model=model,
-                    prefix=prefix,
-                    lora=lora_model,
-                    lora_weight=lora_weight,
-                    original_weights=original_weights,
-                )
-                del lora_model
-
-            yield
-
-        finally:
-            with torch.no_grad():
-                for param_key, weight in original_weights.get_changed_weights():
-                    model.get_parameter(param_key).copy_(weight)
-
     @classmethod
     @contextmanager
     def apply_ti(
@@ -282,26 +179,6 @@
 
 
 class ONNXModelPatcher:
-    @classmethod
-    @contextmanager
-    def apply_lora_unet(
-        cls,
-        unet: OnnxRuntimeModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(unet, loras, "lora_unet_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora_text_encoder(
-        cls,
-        text_encoder: OnnxRuntimeModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te_"):
-            yield
-
     # based on
     # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
     @classmethod
@@ -1,18 +1,17 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING
 
-import torch
 from diffusers import UNet2DConditionModel
 
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoraPatcher
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
-from invokeai.backend.util.devices import TorchDevice
 
 if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.shared.invocation_context import InvocationContext
-    from invokeai.backend.lora import LoRAModelRaw
     from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
@@ -31,107 +30,14 @@ class LoRAExt(ExtensionBase):
     @contextmanager
     def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
-        self.patch_model(
+        assert isinstance(lora_model, LoRAModelRaw)
+        LoraPatcher.apply_lora_patch(
             model=unet,
             prefix="lora_unet_",
-            lora=lora_model,
-            lora_weight=self._weight,
+            patch=lora_model,
+            patch_weight=self._weight,
             original_weights=original_weights,
         )
         del lora_model
 
         yield
-
-    @classmethod
-    @torch.no_grad()
-    def patch_model(
-        cls,
-        model: torch.nn.Module,
-        prefix: str,
-        lora: LoRAModelRaw,
-        lora_weight: float,
-        original_weights: OriginalWeightsStorage,
-    ):
-        """
-        Apply one or more LoRAs to a model.
-        :param model: The model to patch.
-        :param lora: LoRA model to patch in.
-        :param lora_weight: LoRA patch weight.
-        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
-        """
-
-        if lora_weight == 0:
-            return
-
-        # assert lora.device.type == "cpu"
-        for layer_key, layer in lora.layers.items():
-            if not layer_key.startswith(prefix):
-                continue
-
-            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-            # should be improved in the following ways:
-            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-            #    LoRA model is applied.
-            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-            #    weights to have valid keys.
-            assert isinstance(model, torch.nn.Module)
-            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-            # All of the LoRA weight calculations will be done on the same device as the module weight.
-            # (Performance will be best if this is a CUDA device.)
-            device = module.weight.device
-            dtype = module.weight.dtype
-
-            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-            # We intentionally move to the target device first, then cast. Experimentally, this was found to
-            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-            # same thing in a single call to '.to(...)'.
-            layer.to(device=device)
-            layer.to(dtype=torch.float32)
-
-            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-            for param_name, lora_param_weight in layer.get_parameters(module).items():
-                param_key = module_key + "." + param_name
-                module_param = module.get_parameter(param_name)
-
-                # save original weight
-                original_weights.save(param_key, module_param)
-
-                if module_param.shape != lora_param_weight.shape:
-                    # TODO: debug on lycoris
-                    lora_param_weight = lora_param_weight.reshape(module_param.shape)
-
-                lora_param_weight *= lora_weight * layer_scale
-                module_param += lora_param_weight.to(dtype=dtype)
-
-            layer.to(device=TorchDevice.CPU_DEVICE)
-
-    @staticmethod
-    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
-        assert "." not in lora_key
-
-        if not lora_key.startswith(prefix):
-            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
-        module = model
-        module_key = ""
-        key_parts = lora_key[len(prefix) :].split("_")
-
-        submodule_name = key_parts.pop(0)
-
-        while len(key_parts) > 0:
-            try:
-                module = module.get_submodule(submodule_name)
-                module_key += "." + submodule_name
-                submodule_name = key_parts.pop(0)
-            except Exception:
-                submodule_name += "_" + key_parts.pop(0)
-
-        module = module.get_submodule(submodule_name)
-        module_key = (module_key + "." + submodule_name).lstrip(".")
-
-        return (module_key, module)
@@ -12,6 +12,18 @@ module.exports = {
     'i18next/no-literal-string': 'error',
     // https://eslint.org/docs/latest/rules/no-console
    'no-console': 'error',
+    // https://eslint.org/docs/latest/rules/no-promise-executor-return
+    'no-promise-executor-return': 'error',
+    // https://eslint.org/docs/latest/rules/require-await
+    'require-await': 'error',
+    'no-restricted-properties': [
+      'error',
+      {
+        object: 'crypto',
+        property: 'randomUUID',
+        message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
+      },
+    ],
   },
   overrides: [
     /**
@@ -1,5 +1,5 @@
 import { PropsWithChildren, memo, useEffect } from 'react';
-import { modelChanged } from '../src/features/parameters/store/generationSlice';
+import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
 import { useAppDispatch } from '../src/app/store/storeHooks';
 import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
 /**
@@ -10,7 +10,9 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
   const dispatch = useAppDispatch();
   useGlobalModifiersInit();
   useEffect(() => {
-    dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
+    dispatch(
+      modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
+    );
   }, []);
 
   return props.children;
@@ -9,6 +9,10 @@ const config: KnipConfig = {
     'src/services/api/schema.ts',
     'src/features/nodes/types/v1/**',
     'src/features/nodes/types/v2/**',
+    // TODO(psyche): maybe we can clean up these utils after canvas v2 release
+    'src/features/controlLayers/konva/util.ts',
+    // TODO(psyche): restore HRF functionality?
+    'src/features/hrf/**',
   ],
   ignoreBinaries: ['only-allow'],
   paths: {
@@ -24,7 +24,7 @@
     "build": "pnpm run lint && vite build",
     "typegen": "node scripts/typegen.js",
     "preview": "vite preview",
-    "lint:knip": "knip",
+    "lint:knip": "knip --tags=-knipignore",
     "lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
     "lint:eslint": "eslint --max-warnings=0 .",
     "lint:prettier": "prettier --check .",
@@ -52,18 +52,19 @@
     }
   },
   "dependencies": {
-    "@chakra-ui/react-use-size": "^2.1.0",
     "@dagrejs/dagre": "^1.1.3",
     "@dagrejs/graphlib": "^2.2.3",
     "@dnd-kit/core": "^6.1.0",
     "@dnd-kit/sortable": "^8.0.0",
     "@dnd-kit/utilities": "^3.2.2",
     "@fontsource-variable/inter": "^5.0.20",
-    "@invoke-ai/ui-library": "^0.0.29",
+    "@invoke-ai/ui-library": "^0.0.33",
     "@nanostores/react": "^0.7.3",
     "@reduxjs/toolkit": "2.2.3",
     "@roarr/browser-log-writer": "^1.3.0",
+    "async-mutex": "^0.5.0",
     "chakra-react-select": "^4.9.1",
+    "cmdk": "^1.0.0",
     "compare-versions": "^6.1.1",
     "dateformat": "^5.0.3",
     "fracturedjsonjs": "^4.0.2",
@@ -74,6 +75,8 @@
     "jsondiffpatch": "^0.6.0",
     "konva": "^9.3.14",
     "lodash-es": "^4.17.21",
+    "lru-cache": "^11.0.0",
+    "nanoid": "^5.0.7",
     "nanostores": "^0.11.2",
     "new-github-issue-url": "^1.0.0",
     "overlayscrollbars": "^2.10.0",
@@ -88,10 +91,8 @@
     "react-hotkeys-hook": "4.5.0",
     "react-i18next": "^14.1.3",
     "react-icons": "^5.2.1",
-    "react-konva": "^18.2.10",
     "react-redux": "9.1.2",
-    "react-resizable-panels": "^2.0.23",
+    "react-resizable-panels": "^2.1.2",
-    "react-select": "5.8.0",
     "react-use": "^17.5.1",
     "react-virtuoso": "^4.9.0",
     "reactflow": "^11.11.4",
@@ -102,9 +103,9 @@
     "roarr": "^7.21.1",
     "serialize-error": "^11.0.3",
     "socket.io-client": "^4.7.5",
+    "stable-hash": "^0.0.4",
     "use-debounce": "^10.0.2",
     "use-device-pixel-ratio": "^1.1.2",
-    "use-image": "^1.1.1",
     "uuid": "^10.0.0",
     "zod": "^3.23.8",
     "zod-validation-error": "^3.3.1"
@@ -135,6 +136,7 @@
     "@vitest/coverage-v8": "^1.5.0",
     "@vitest/ui": "^1.5.0",
     "concurrently": "^8.2.2",
+    "csstype": "^3.1.3",
     "dpdm": "^3.14.0",
     "eslint": "^8.57.0",
     "eslint-plugin-i18next": "^6.0.9",
invokeai/frontend/web/pnpm-lock.yaml (generated, 643 lines changed): file diff suppressed because it is too large.
@@ -127,7 +127,14 @@
     "bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
     "bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
     "bulkDownloadFailed": "Download fehlgeschlagen",
-    "alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
+    "alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen",
+    "selectForCompare": "Zum Vergleichen auswählen",
+    "compareImage": "Bilder vergleichen",
+    "exitSearch": "Suche beenden",
+    "newestFirst": "Neueste zuerst",
+    "oldestFirst": "Älteste zuerst",
+    "openInViewer": "Im Viewer öffnen",
+    "swapImages": "Bilder tauschen"
   },
   "hotkeys": {
     "keyboardShortcuts": "Tastenkürzel",
@@ -631,7 +638,8 @@
     "archived": "Archiviert",
     "noBoards": "Kein {boardType}} Ordner",
     "hideBoards": "Ordner verstecken",
-    "viewBoards": "Ordner ansehen"
+    "viewBoards": "Ordner ansehen",
+    "deletedPrivateBoardsCannotbeRestored": "Gelöschte Boards können nicht wiederhergestellt werden. Wenn Sie „Nur Board löschen“ wählen, werden die Bilder in einen privaten, nicht kategorisierten Status für den Ersteller des Bildes versetzt."
   },
   "controlnet": {
     "showAdvanced": "Zeige Erweitert",
@@ -781,7 +789,9 @@
     "batchFieldValues": "Stapelverarbeitungswerte",
     "batchQueued": "Stapelverarbeitung eingereiht",
     "graphQueued": "Graph eingereiht",
-    "graphFailedToQueue": "Fehler beim Einreihen des Graphen"
+    "graphFailedToQueue": "Fehler beim Einreihen des Graphen",
+    "generations_one": "Generation",
+    "generations_other": "Generationen"
   },
   "metadata": {
     "negativePrompt": "Negativ Beschreibung",
@@ -1146,5 +1156,10 @@
     "noMatchingTriggers": "Keine passenden Trigger",
     "addPromptTrigger": "Prompt-Trigger hinzufügen",
     "compatibleEmbeddings": "Kompatible Einbettungen"
+  },
+  "ui": {
+    "tabs": {
+      "queue": "Warteschlange"
+    }
   }
 }
@@ -80,6 +80,7 @@
     "aboutDesc": "Using Invoke for work? Check out:",
     "aboutHeading": "Own Your Creative Power",
     "accept": "Accept",
+    "apply": "Apply",
     "add": "Add",
     "advanced": "Advanced",
     "ai": "ai",
@@ -92,6 +93,7 @@
     "copy": "Copy",
     "copyError": "$t(gallery.copy) Error",
     "on": "On",
+    "off": "Off",
     "or": "or",
     "checkpoint": "Checkpoint",
     "communityLabel": "Community",
@@ -115,6 +117,7 @@
     "githubLabel": "Github",
     "goTo": "Go to",
     "hotkeysLabel": "Hotkeys",
+    "loadingImage": "Loading Image",
     "imageFailedToLoad": "Unable to Load Image",
     "img2img": "Image To Image",
     "inpaint": "inpaint",
@@ -132,6 +135,7 @@
     "nodes": "Workflows",
     "notInstalled": "Not $t(common.installed)",
     "openInNewTab": "Open in New Tab",
+    "openInViewer": "Open in Viewer",
     "orderBy": "Order By",
     "outpaint": "outpaint",
     "outputs": "Outputs",
@@ -162,10 +166,10 @@
     "alpha": "Alpha",
     "selected": "Selected",
     "tab": "Tab",
-    "viewing": "Viewing",
-    "viewingDesc": "Review images in a large gallery view",
-    "editing": "Editing",
-    "editingDesc": "Edit on the Control Layers canvas",
+    "view": "View",
+    "viewDesc": "Review images in a large gallery view",
+    "edit": "Edit",
+    "editDesc": "Edit on the Canvas",
     "comparing": "Comparing",
     "comparingDesc": "Comparing two images",
     "enabled": "Enabled",
@@ -325,6 +329,14 @@
     "canceled": "Canceled",
     "completedIn": "Completed in",
     "batch": "Batch",
+    "origin": "Origin",
+    "destination": "Destination",
+    "upscaling": "Upscaling",
+    "canvas": "Canvas",
+    "generation": "Generation",
+    "workflows": "Workflows",
+    "other": "Other",
+    "gallery": "Gallery",
     "batchFieldValues": "Batch Field Values",
     "item": "Item",
     "session": "Session",
@@ -363,6 +375,7 @@
     "useCache": "Use Cache"
   },
   "gallery": {
+    "gallery": "Gallery",
     "alwaysShowImageSizeBadge": "Always Show Image Size Badge",
     "assets": "Assets",
     "autoAssignBoardOnClick": "Auto-Assign Board on Click",
@@ -375,11 +388,11 @@
     "deleteImage_one": "Delete Image",
     "deleteImage_other": "Delete {{count}} Images",
     "deleteImagePermanent": "Deleted images cannot be restored.",
-    "displayBoardSearch": "Display Board Search",
-    "displaySearch": "Display Search",
+    "displayBoardSearch": "Board Search",
+    "displaySearch": "Image Search",
     "download": "Download",
     "exitBoardSearch": "Exit Board Search",
-    "exitSearch": "Exit Search",
+    "exitSearch": "Exit Image Search",
     "featuresWillReset": "If you delete this image, those features will immediately be reset.",
     "galleryImageSize": "Image Size",
     "gallerySettings": "Gallery Settings",
@@ -425,7 +438,8 @@
     "compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
     "compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
     "compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
-    "compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
+    "compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
+    "toggleMiniViewer": "Toggle Mini Viewer"
   },
   "hotkeys": {
     "searchHotkeys": "Search Hotkeys",
@@ -1004,6 +1018,8 @@
     "noModelForControlAdapter": "Control Adapter #{{number}} has no model selected.",
     "incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
     "noModelSelected": "No model selected",
+    "canvasManagerNotLoaded": "Canvas Manager not loaded",
+    "canvasBusy": "Canvas is busy",
     "noPrompts": "No prompts generated",
     "noNodesInGraph": "No nodes in graph",
     "systemDisconnected": "System disconnected",
@@ -1035,12 +1051,11 @@
     "scaledHeight": "Scaled H",
     "scaledWidth": "Scaled W",
     "scheduler": "Scheduler",
-    "seamlessXAxis": "Seamless Tiling X Axis",
-    "seamlessYAxis": "Seamless Tiling Y Axis",
+    "seamlessXAxis": "Seamless X Axis",
+    "seamlessYAxis": "Seamless Y Axis",
     "seed": "Seed",
     "imageActions": "Image Actions",
-    "sendToImg2Img": "Send to Image to Image",
-    "sendToUnifiedCanvas": "Send To Unified Canvas",
+    "sendToCanvas": "Send To Canvas",
     "sendToUpscale": "Send To Upscale",
     "showOptionsPanel": "Show Side Panel (O or T)",
     "shuffle": "Shuffle Seed",
@@ -1100,7 +1115,6 @@
     "confirmOnDelete": "Confirm On Delete",
     "developer": "Developer",
     "displayInProgress": "Display Progress Images",
-    "enableImageDebugging": "Enable Image Debugging",
     "enableInformationalPopovers": "Enable Informational Popovers",
     "informationalPopoversDisabled": "Informational Popovers Disabled",
     "informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
@@ -1182,8 +1196,8 @@
     "problemSavingMaskDesc": "Unable to export mask",
     "prunedQueue": "Pruned Queue",
     "resetInitialImage": "Reset Initial Image",
-    "sentToImageToImage": "Sent To Image To Image",
-    "sentToUnifiedCanvas": "Sent to Unified Canvas",
+    "sentToCanvas": "Sent to Canvas",
+    "sentToUpscale": "Sent to Upscale",
     "serverError": "Server Error",
     "sessionRef": "Session: {{sessionId}}",
     "setAsCanvasInitialImage": "Set as canvas initial image",
|
"setAsCanvasInitialImage": "Set as canvas initial image",
|
||||||
@@ -1567,7 +1581,7 @@
|
|||||||
"copyToClipboard": "Copy to Clipboard",
|
"copyToClipboard": "Copy to Clipboard",
|
||||||
"cursorPosition": "Cursor Position",
|
"cursorPosition": "Cursor Position",
|
||||||
"darkenOutsideSelection": "Darken Outside Selection",
|
"darkenOutsideSelection": "Darken Outside Selection",
|
||||||
"discardAll": "Discard All",
|
"discardAll": "Discard All & Cancel Pending Generations",
|
||||||
"discardCurrent": "Discard Current",
|
"discardCurrent": "Discard Current",
|
||||||
"downloadAsImage": "Download As Image",
|
"downloadAsImage": "Download As Image",
|
||||||
"enableMask": "Enable Mask",
|
"enableMask": "Enable Mask",
|
||||||
@@ -1645,39 +1659,187 @@
|
|||||||
"storeNotInitialized": "Store is not initialized"
|
"storeNotInitialized": "Store is not initialized"
|
||||||
},
|
},
|
||||||
"controlLayers": {
|
"controlLayers": {
|
||||||
"deleteAll": "Delete All",
|
"bookmark": "Bookmark for Quick Switch",
|
||||||
|
"fitBboxToLayers": "Fit Bbox To Layers",
|
||||||
|
"removeBookmark": "Remove Bookmark",
|
||||||
|
"saveCanvasToGallery": "Save Canvas To Gallery",
|
||||||
|
"saveBboxToGallery": "Save Bbox To Gallery",
|
||||||
|
"savedToGalleryOk": "Saved to Gallery",
|
||||||
|
"savedToGalleryError": "Error saving to gallery",
|
||||||
|
"mergeVisible": "Merge Visible",
|
||||||
|
"mergeVisibleOk": "Merged visible layers",
|
||||||
|
"mergeVisibleError": "Error merging visible layers",
|
||||||
|
"clearHistory": "Clear History",
|
||||||
|
"generateMode": "Generate",
|
||||||
|
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
|
||||||
|
"composeMode": "Compose",
|
||||||
|
"composeModeDesc": "Compose your work iterative. Generated images are added back to the canvas.",
|
||||||
|
"autoSave": "Auto-save to Gallery",
|
||||||
|
"resetCanvas": "Reset Canvas",
|
||||||
|
"resetAll": "Reset All",
|
||||||
|
"clearCaches": "Clear Caches",
|
||||||
|
"recalculateRects": "Recalculate Rects",
|
||||||
|
"clipToBbox": "Clip Strokes to Bbox",
|
||||||
|
"compositeMaskedRegions": "Composite Masked Regions",
|
||||||
"addLayer": "Add Layer",
|
"addLayer": "Add Layer",
|
||||||
|
"duplicate": "Duplicate",
|
||||||
"moveToFront": "Move to Front",
|
"moveToFront": "Move to Front",
|
||||||
"moveToBack": "Move to Back",
|
"moveToBack": "Move to Back",
|
||||||
"moveForward": "Move Forward",
|
"moveForward": "Move Forward",
|
||||||
"moveBackward": "Move Backward",
|
"moveBackward": "Move Backward",
|
||||||
"brushSize": "Brush Size",
|
"brushSize": "Brush Size",
|
||||||
|
"width": "Width",
|
||||||
|
"zoom": "Zoom",
|
||||||
|
"resetView": "Reset View",
|
||||||
"controlLayers": "Control Layers",
|
"controlLayers": "Control Layers",
|
||||||
"globalMaskOpacity": "Global Mask Opacity",
|
"globalMaskOpacity": "Global Mask Opacity",
|
||||||
"autoNegative": "Auto Negative",
|
"autoNegative": "Auto Negative",
|
||||||
|
"enableAutoNegative": "Enable Auto Negative",
|
||||||
|
"disableAutoNegative": "Disable Auto Negative",
|
||||||
"deletePrompt": "Delete Prompt",
|
"deletePrompt": "Delete Prompt",
|
||||||
"resetRegion": "Reset Region",
|
"resetRegion": "Reset Region",
|
||||||
"debugLayers": "Debug Layers",
|
"debugLayers": "Debug Layers",
|
||||||
|
"showHUD": "Show HUD",
|
||||||
"rectangle": "Rectangle",
|
"rectangle": "Rectangle",
|
||||||
"maskPreviewColor": "Mask Preview Color",
|
"maskFill": "Mask Fill",
|
||||||
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
||||||
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
||||||
"addIPAdapter": "Add $t(common.ipAdapter)",
|
"addIPAdapter": "Add $t(common.ipAdapter)",
|
||||||
"regionalGuidance": "Regional Guidance",
|
"addRasterLayer": "Add $t(controlLayers.rasterLayer)",
|
||||||
|
"addControlLayer": "Add $t(controlLayers.controlLayer)",
|
||||||
|
"addInpaintMask": "Add $t(controlLayers.inpaintMask)",
|
||||||
|
"addRegionalGuidance": "Add $t(controlLayers.regionalGuidance)",
|
||||||
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
||||||
|
"raster": "Raster",
|
||||||
|
"rasterLayer": "Raster Layer",
|
||||||
|
"controlLayer": "Control Layer",
|
||||||
|
"inpaintMask": "Inpaint Mask",
|
||||||
|
"regionalGuidance": "Regional Guidance",
|
||||||
|
"ipAdapter": "IP Adapter",
|
||||||
|
"sendingToCanvas": "Sending to Canvas",
|
||||||
|
"sendingToGallery": "Sending to Gallery",
|
||||||
|
"sendToGallery": "Send To Gallery",
|
||||||
|
"sendToGalleryDesc": "Generations will be sent to the gallery.",
|
||||||
|
"sendToCanvas": "Send To Canvas",
|
||||||
|
"sendToCanvasDesc": "Generations will be staged onto the canvas.",
|
||||||
|
"rasterLayer_withCount_one": "$t(controlLayers.rasterLayer)",
|
||||||
|
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
|
||||||
|
"inpaintMask_withCount_one": "$t(controlLayers.inpaintMask)",
|
||||||
|
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
|
||||||
|
"ipAdapter_withCount_one": "$t(controlLayers.ipAdapter)",
|
||||||
|
"rasterLayer_withCount_other": "Raster Layers",
|
||||||
|
"controlLayer_withCount_other": "Control Layers",
|
||||||
|
"inpaintMask_withCount_other": "Inpaint Masks",
|
||||||
|
"regionalGuidance_withCount_other": "Regional Guidance",
|
||||||
|
"ipAdapter_withCount_other": "IP Adapters",
|
||||||
"opacity": "Opacity",
|
"opacity": "Opacity",
|
||||||
|
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
|
||||||
|
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
|
||||||
|
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
|
||||||
|
"globalIPAdapters_withCount_hidden": "Global IP Adapters ({{count}} hidden)",
|
||||||
|
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
|
||||||
|
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
|
||||||
|
"controlLayers_withCount_visible": "Control Layers ({{count}})",
|
||||||
|
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
|
||||||
|
"globalIPAdapters_withCount_visible": "Global IP Adapters ({{count}})",
|
||||||
|
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
|
||||||
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
|
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
|
||||||
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
||||||
"globalIPAdapter": "Global $t(common.ipAdapter)",
|
"globalIPAdapter": "Global $t(common.ipAdapter)",
|
||||||
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
|
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
|
||||||
"globalInitialImage": "Global Initial Image",
|
"globalInitialImage": "Global Initial Image",
|
||||||
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
|
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
|
||||||
|
"layer": "Layer",
|
||||||
"opacityFilter": "Opacity Filter",
|
"opacityFilter": "Opacity Filter",
|
||||||
"clearProcessor": "Clear Processor",
|
"clearProcessor": "Clear Processor",
|
||||||
"resetProcessor": "Reset Processor to Defaults",
|
"resetProcessor": "Reset Processor to Defaults",
|
||||||
"noLayersAdded": "No Layers Added",
|
"noLayersAdded": "No Layers Added",
|
||||||
"layers_one": "Layer",
|
"layer_one": "Layer",
|
||||||
"layers_other": "Layers"
|
"layer_other": "Layers",
|
||||||
|
"objects_zero": "empty",
|
||||||
|
"objects_one": "{{count}} object",
|
||||||
|
"objects_other": "{{count}} objects",
|
||||||
|
"convertToControlLayer": "Convert to Control Layer",
|
||||||
|
"convertToRasterLayer": "Convert to Raster Layer",
|
||||||
|
"transparency": "Transparency",
|
||||||
|
"enableTransparencyEffect": "Enable Transparency Effect",
|
||||||
|
"disableTransparencyEffect": "Disable Transparency Effect",
|
||||||
|
"hidingType": "Hiding {{type}}",
|
||||||
|
"showingType": "Showing {{type}}",
|
||||||
|
"dynamicGrid": "Dynamic Grid",
|
||||||
|
"logDebugInfo": "Log Debug Info",
|
||||||
|
"locked": "Locked",
|
||||||
|
"unlocked": "Unlocked",
|
||||||
|
"deleteSelected": "Delete Selected",
|
||||||
|
"deleteAll": "Delete All",
|
||||||
|
"flipHorizontal": "Flip Horizontal",
|
||||||
|
"flipVertical": "Flip Vertical",
|
||||||
|
"fill": {
|
||||||
|
"fillColor": "Fill Color",
|
||||||
|
"fillStyle": "Fill Style",
|
||||||
|
"solid": "Solid",
|
||||||
|
"grid": "Grid",
|
||||||
|
"crosshatch": "Crosshatch",
|
||||||
|
"vertical": "Vertical",
|
||||||
|
"horizontal": "Horizontal",
|
||||||
|
"diagonal": "Diagonal"
|
||||||
|
},
|
||||||
|
"tool": {
|
||||||
|
"brush": "Brush",
|
||||||
|
"eraser": "Eraser",
|
||||||
|
"rectangle": "Rectangle",
|
||||||
|
"bbox": "Bbox",
|
||||||
|
"move": "Move",
|
||||||
|
"view": "View",
|
||||||
|
"colorPicker": "Color Picker"
|
||||||
|
},
|
||||||
|
"filter": {
|
||||||
|
"filter": "Filter",
|
||||||
|
"filters": "Filters",
|
||||||
|
"filterType": "Filter Type",
|
||||||
|
"autoProcess": "Auto Process",
|
||||||
|
"reset": "Reset",
|
||||||
|
"process": "Process",
|
||||||
|
"apply": "Apply",
|
||||||
|
"cancel": "Cancel",
|
||||||
|
"spandrel": {
|
||||||
|
"label": "Image-to-Image Model",
|
||||||
|
"description": "Run an image-to-image model on the selected layer.",
|
||||||
|
"paramModel": "Model",
|
||||||
|
"paramAutoScale": "Auto Scale",
|
||||||
|
"paramAutoScaleDesc": "The selected model will be run until the target scale is reached.",
|
||||||
|
"paramScale": "Target Scale"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"transform": {
|
||||||
|
"transform": "Transform",
|
||||||
|
"fitToBbox": "Fit to Bbox",
|
||||||
|
"reset": "Reset",
|
||||||
|
"apply": "Apply",
|
||||||
|
"cancel": "Cancel"
|
||||||
|
},
|
||||||
|
"settings": {
|
||||||
|
"snapToGrid": {
|
||||||
|
"label": "Snap to Grid",
|
||||||
|
"on": "On",
|
||||||
|
"off": "Off"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"HUD": {
|
||||||
|
"bbox": "Bbox",
|
||||||
|
"scaledBbox": "Scaled Bbox",
|
||||||
|
"autoSave": "Auto Save",
|
||||||
|
"entityStatus": {
|
||||||
|
"selectedEntity": "Selected Entity",
|
||||||
|
"selectedEntityIs": "Selected Entity is",
|
||||||
|
"isFiltering": "is filtering",
|
||||||
|
"isTransforming": "is transforming",
|
||||||
|
"isLocked": "is locked",
|
||||||
|
"isHidden": "is hidden",
|
||||||
|
"isDisabled": "is disabled",
|
||||||
|
"enabled": "Enabled"
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"upscaling": {
|
"upscaling": {
|
||||||
"upscale": "Upscale",
|
"upscale": "Upscale",
|
||||||
@@ -1765,5 +1927,30 @@
|
|||||||
"upscaling": "Upscaling",
|
"upscaling": "Upscaling",
|
||||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"system": {
|
||||||
|
"enableLogging": "Enable Logging",
|
||||||
|
"logLevel": {
|
||||||
|
"logLevel": "Log Level",
|
||||||
|
"trace": "Trace",
|
||||||
|
"debug": "Debug",
|
||||||
|
"info": "Info",
|
||||||
|
"warn": "Warn",
|
||||||
|
"error": "Error",
|
||||||
|
"fatal": "Fatal"
|
||||||
|
},
|
||||||
|
"logNamespaces": {
|
||||||
|
"logNamespaces": "Log Namespaces",
|
||||||
|
"gallery": "Gallery",
|
||||||
|
"models": "Models",
|
||||||
|
"config": "Config",
|
||||||
|
"canvas": "Canvas",
|
||||||
|
"generation": "Generation",
|
||||||
|
"workflows": "Workflows",
|
||||||
|
"system": "System",
|
||||||
|
"events": "Events",
|
||||||
|
"queue": "Queue",
|
||||||
|
"metadata": "Metadata"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
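The `_withCount_one` / `_other` suffixes and `$t(...)` references added above follow i18next's plural and nesting conventions. A minimal sketch of how such keys resolve, assuming stock i18next; the resource object below is a toy subset for illustration, not the app's locale file:

```ts
import i18next from 'i18next';

i18next
  .init({
    lng: 'en',
    resources: {
      en: {
        translation: {
          controlLayers: {
            rasterLayer: 'Raster Layer',
            // `$t(...)` is i18next's nesting syntax: it resolves to another key's value.
            rasterLayer_withCount_one: '$t(controlLayers.rasterLayer)',
            rasterLayer_withCount_other: 'Raster Layers',
            objects_one: '{{count}} object',
            objects_other: '{{count}} objects',
          },
        },
      },
    },
  })
  .then(() => {
    // i18next appends the `_one` / `_other` plural suffix based on `count`.
    console.log(i18next.t('controlLayers.rasterLayer_withCount', { count: 1 })); // "Raster Layer"
    console.log(i18next.t('controlLayers.rasterLayer_withCount', { count: 3 })); // "Raster Layers"
    console.log(i18next.t('controlLayers.objects', { count: 3 })); // "3 objects"
  });
```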
@@ -86,15 +86,15 @@
"loadMore": "Cargar más",
"noImagesInGallery": "No hay imágenes para mostrar",
"deleteImage_one": "Eliminar Imagen",
-"deleteImage_many": "",
+"deleteImage_many": "Eliminar {{count}} Imágenes",
-"deleteImage_other": "",
+"deleteImage_other": "Eliminar {{count}} Imágenes",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",
"autoAssignBoardOnClick": "Asignación automática de tableros al hacer clic"
},
"hotkeys": {
"keyboardShortcuts": "Atajos de teclado",
-"appHotkeys": "Atajos de applicación",
+"appHotkeys": "Atajos de aplicación",
"generalHotkeys": "Atajos generales",
"galleryHotkeys": "Atajos de galería",
"unifiedCanvasHotkeys": "Atajos de lienzo unificado",
@@ -535,7 +535,7 @@
"bottomMessage": "Al eliminar este panel y las imágenes que contiene, se restablecerán las funciones que los estén utilizando actualmente.",
"deleteBoardAndImages": "Borrar el panel y las imágenes",
"loading": "Cargando...",
-"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar",
+"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar. Al Seleccionar 'Borrar Solo el Panel' transferirá las imágenes a un estado sin categorizar.",
"move": "Mover",
"menuItemAutoAdd": "Agregar automáticamente a este panel",
"searchBoard": "Buscando paneles…",
@@ -549,7 +549,13 @@
"imagesWithCount_other": "{{count}} imágenes",
"assetsWithCount_one": "{{count}} activo",
"assetsWithCount_many": "{{count}} activos",
-"assetsWithCount_other": "{{count}} activos"
+"assetsWithCount_other": "{{count}} activos",
+"hideBoards": "Ocultar Paneles",
+"addPrivateBoard": "Agregar un tablero privado",
+"addSharedBoard": "Agregar Panel Compartido",
+"boards": "Paneles",
+"archiveBoard": "Archivar Panel",
+"archived": "Archivado"
},
"accordions": {
"compositing": {
@@ -496,7 +496,9 @@
"main": "Principali",
"noModelsInstalledDesc1": "Installa i modelli con",
"ipAdapters": "Adattatori IP",
-"noMatchingModels": "Nessun modello corrispondente"
+"noMatchingModels": "Nessun modello corrispondente",
+"starterModelsInModelManager": "I modelli iniziali possono essere trovati in Gestione Modelli",
+"spandrelImageToImage": "Immagine a immagine (Spandrel)"
},
"parameters": {
"images": "Immagini",
@@ -510,7 +512,7 @@
"perlinNoise": "Rumore Perlin",
"type": "Tipo",
"strength": "Forza",
-"upscaling": "Ampliamento",
+"upscaling": "Amplia",
"scale": "Scala",
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
"scaleBeforeProcessing": "Scala prima dell'elaborazione",
@@ -593,7 +595,7 @@
"globalPositivePromptPlaceholder": "Prompt positivo globale",
"globalNegativePromptPlaceholder": "Prompt negativo globale",
"processImage": "Elabora Immagine",
-"sendToUpscale": "Invia a Ampliare",
+"sendToUpscale": "Invia a Amplia",
"postProcessing": "Post-elaborazione (Shift + U)"
},
"settings": {
@@ -1420,7 +1422,7 @@
"paramUpscaleMethod": {
"heading": "Metodo di ampliamento",
"paragraphs": [
-"Metodo utilizzato per eseguire l'ampliamento dell'immagine per la correzione ad alta risoluzione."
+"Metodo utilizzato per ampliare l'immagine per la correzione ad alta risoluzione."
]
},
"patchmatchDownScaleSize": {
@@ -1528,7 +1530,7 @@
},
"upscaleModel": {
"paragraphs": [
-"Il modello di ampliamento (Upscale), scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
+"Il modello di ampliamento, scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
],
"heading": "Modello di ampliamento"
},
@@ -1720,26 +1722,27 @@
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Coda",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
-"upscaling": "Ampliamento",
+"upscaling": "Amplia",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"upscaling": {
"creativity": "Creatività",
"structure": "Struttura",
-"upscaleModel": "Modello di Ampliamento",
+"upscaleModel": "Modello di ampliamento",
"scale": "Scala",
"missingModelsWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli richiesti:",
"mainModelDesc": "Modello principale (architettura SD1.5 o SDXL)",
"tileControlNetModelDesc": "Modello Tile ControlNet per l'architettura del modello principale scelto",
-"upscaleModelDesc": "Modello per l'ampliamento (da immagine a immagine)",
+"upscaleModelDesc": "Modello per l'ampliamento (immagine a immagine)",
"missingUpscaleInitialImage": "Immagine iniziale mancante per l'ampliamento",
"missingUpscaleModel": "Modello per l’ampliamento mancante",
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
"postProcessingModel": "Modello di post-elaborazione",
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine).",
"exceedsMaxSize": "Le impostazioni di ampliamento superano il limite massimo delle dimensioni",
-"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata."
+"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata.",
+"upscale": "Amplia"
},
"upsell": {
"inviteTeammates": "Invita collaboratori",
@@ -1789,6 +1792,7 @@
"positivePromptColumn": "'prompt' o 'positive_prompt'",
"noTemplates": "Nessun modello",
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
-"templateActions": "Azioni modello"
+"templateActions": "Azioni modello",
+"promptTemplateCleared": "Modello di prompt cancellato"
}
}
@@ -501,7 +501,8 @@
"noModelsInstalled": "Нет установленных моделей",
"noModelsInstalledDesc1": "Установите модели с помощью",
"noMatchingModels": "Нет подходящих моделей",
-"ipAdapters": "IP адаптеры"
+"ipAdapters": "IP адаптеры",
+"starterModelsInModelManager": "Стартовые модели можно найти в Менеджере моделей"
},
"parameters": {
"images": "Изображения",
@@ -1758,7 +1759,8 @@
"postProcessingModel": "Модель постобработки",
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
-"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
+"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img).",
+"upscale": "Увеличить"
},
"stylePresets": {
"noMatchingTemplates": "Нет подходящих шаблонов",
@@ -1804,7 +1806,8 @@
"noTemplates": "Нет шаблонов",
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
"searchByName": "Поиск по имени",
-"shared": "Общий"
+"shared": "Общий",
+"promptTemplateCleared": "Шаблон запроса создан"
},
"upsell": {
"inviteTeammates": "Пригласите членов команды",
@@ -154,7 +154,8 @@
"displaySearch": "显示搜索",
"stretchToFit": "拉伸以适应",
"exitCompare": "退出对比",
-"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。"
+"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。",
+"go": "运行"
},
"hotkeys": {
"keyboardShortcuts": "快捷键",
@@ -494,7 +495,9 @@
"huggingFacePlaceholder": "所有者或模型名称",
"huggingFaceRepoID": "HuggingFace仓库ID",
"loraTriggerPhrases": "LoRA 触发词",
-"ipAdapters": "IP适配器"
+"ipAdapters": "IP适配器",
+"spandrelImageToImage": "图生图(Spandrel)",
+"starterModelsInModelManager": "您可以在模型管理器中找到初始模型"
},
"parameters": {
"images": "图像",
@@ -695,7 +698,9 @@
"outOfMemoryErrorDesc": "您当前的生成设置已超出系统处理能力.请调整设置后再次尝试.",
"parametersSet": "参数已恢复",
"errorCopied": "错误信息已复制",
-"modelImportCanceled": "模型导入已取消"
+"modelImportCanceled": "模型导入已取消",
+"importFailed": "导入失败",
+"importSuccessful": "导入成功"
},
"unifiedCanvas": {
"layer": "图层",
@@ -1705,12 +1710,55 @@
"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
"mainModelDesc": "主模型(SD1.5或SDXL架构)",
"exceedsMaxSize": "放大设置超出了最大尺寸限制",
-"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择."
+"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择.",
+"upscale": "放大"
},
"upsell": {
"inviteTeammates": "邀请团队成员",
"professional": "专业",
"professionalUpsell": "可在 Invoke 的专业版中使用.点击此处或访问 invoke.com/pricing 了解更多详情.",
"shareAccess": "共享访问权限"
+},
+"stylePresets": {
+"positivePrompt": "正向提示词",
+"preview": "预览",
+"deleteImage": "删除图像",
+"deleteTemplate": "删除模版",
+"deleteTemplate2": "您确定要删除这个模板吗?请注意,删除后无法恢复.",
+"importTemplates": "导入提示模板,支持CSV或JSON格式",
+"insertPlaceholder": "插入一个占位符",
+"myTemplates": "我的模版",
+"name": "名称",
+"type": "类型",
+"unableToDeleteTemplate": "无法删除提示模板",
+"updatePromptTemplate": "更新提示词模版",
+"exportPromptTemplates": "导出我的提示模板为CSV格式",
+"exportDownloaded": "导出已下载",
+"noMatchingTemplates": "无匹配的模版",
+"promptTemplatesDesc1": "提示模板可以帮助您在编写提示时添加预设的文本内容.",
+"promptTemplatesDesc3": "如果您没有使用占位符,那么模板的内容将会被添加到您提示的末尾.",
+"searchByName": "按名称搜索",
+"shared": "已分享",
+"sharedTemplates": "已分享的模版",
+"templateActions": "模版操作",
+"templateDeleted": "提示模版已删除",
+"toggleViewMode": "切换显示模式",
+"uploadImage": "上传图像",
+"active": "激活",
+"choosePromptTemplate": "选择提示词模板",
+"clearTemplateSelection": "清除模版选择",
+"copyTemplate": "拷贝模版",
+"createPromptTemplate": "创建提示词模版",
+"defaultTemplates": "默认模版",
+"editTemplate": "编辑模版",
+"exportFailed": "无法生成并下载CSV文件",
+"flatten": "将选定的模板内容合并到当前提示中",
+"negativePrompt": "反向提示词",
+"promptTemplateCleared": "提示模板已清除",
+"useForTemplate": "用于提示词模版",
+"viewList": "预览模版列表",
+"viewModeTooltip": "这是您的提示在当前选定的模板下的预览效果。如需编辑提示,请直接在文本框中点击进行修改.",
+"noTemplates": "无模版",
+"private": "私密"
}
}
@@ -38,7 +38,7 @@ async function generateTypes(schema) {
  process.stdout.write(`\nOK!\r\n`);
}

-async function main() {
+function main() {
  const encoding = 'utf-8';

  if (process.stdin.isTTY) {
@@ -6,6 +6,7 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
+import { useScopeFocusWatcher } from 'common/hooks/interactionScopes';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
@@ -13,13 +14,16 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
+import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
+import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
+import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
import { configChanged } from 'features/system/store/configSlice';
-import { languageSelector } from 'features/system/store/systemSelectors';
+import { selectLanguage } from 'features/system/store/systemSelectors';
-import InvokeTabs from 'features/ui/components/InvokeTabs';
+import { AppContent } from 'features/ui/components/AppContent';
-import type { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
+import type { TabName } from 'features/ui/store/uiTypes';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
@@ -41,7 +45,7 @@ interface Props {
  };
  selectedWorkflowId?: string;
  selectedStylePresetId?: string;
-  destination?: InvokeTabName | undefined;
+  destination?: TabName;
}

const App = ({
@@ -51,7 +55,7 @@ const App = ({
  selectedStylePresetId,
  destination,
}: Props) => {
-  const language = useAppSelector(languageSelector);
+  const language = useAppSelector(selectLanguage);
  const logger = useLogger('system');
  const dispatch = useAppDispatch();
  const clearStorage = useClearStorage();
@@ -107,6 +111,7 @@ const App = ({

  useStarterModelsToast();
  useSyncQueueStatus();
+  useScopeFocusWatcher();

  return (
    <ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
@@ -119,7 +124,7 @@ const App = ({
        {...dropzone.getRootProps()}
      >
        <input {...dropzone.getInputProps()} />
-        <InvokeTabs />
+        <AppContent />
        <AnimatePresence>
          {dropzone.isDragActive && isHandlingUpload && (
            <ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
@@ -130,7 +135,10 @@ const App = ({
      <ChangeBoardModal />
      <DynamicPromptsModal />
      <StylePresetModal />
+      <ClearQueueConfirmationsAlertDialog />
      <PreselectedImage selectedImage={selectedImage} />
+      <SettingsModal />
+      <RefreshAfterResetModal />
    </ErrorBoundary>
  );
};
@@ -1,5 +1,7 @@
import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
+import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
+import { selectConfigSlice } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
@@ -13,9 +15,11 @@ type Props = {
  resetErrorBoundary: () => void;
};

+const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);

const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
  const { t } = useTranslation();
-  const isLocal = useAppSelector((s) => s.config.isLocal);
+  const isLocal = useAppSelector(selectIsLocal);

  const handleCopy = useCallback(() => {
    const text = JSON.stringify(serializeError(error), null, 2);
@@ -19,7 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
-import type { InvokeTabName } from 'features/ui/store/tabMap';
+import type { TabName } from 'features/ui/store/uiTypes';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@@ -46,7 +46,7 @@ interface Props extends PropsWithChildren {
  };
  selectedWorkflowId?: string;
  selectedStylePresetId?: string;
-  destination?: InvokeTabName;
+  destination?: TabName;
  customStarUi?: CustomStarUi;
  socketOptions?: Partial<ManagerOptions & SocketOptions>;
  isDebugging?: boolean;
@@ -21,10 +21,16 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
      direction,
      shadows: {
        ..._theme.shadows,
+        selected:
+          'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
+        hoverSelected:
+          'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
+        hoverUnselected:
+          'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
        selectedForCompare:
-          '0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
+          'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
        hoverSelectedForCompare:
-          '0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
+          'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
      },
    });
  }, [direction]);
@@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
-import { useAppDispatch } from 'app/store/storeHooks';
+import { useAppStore } from 'app/store/nanostores/store';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
@@ -18,14 +18,19 @@ declare global {
  }
}

+export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;

+export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});

const $isSocketInitialized = atom<boolean>(false);
+export const $isConnected = atom<boolean>(false);

/**
 * Initializes the socket.io connection and sets up event listeners.
 */
export const useSocketIO = () => {
-  const dispatch = useAppDispatch();
+  const { dispatch, getState } = useAppStore();
  const baseUrl = useStore($baseUrl);
  const authToken = useStore($authToken);
  const addlSocketOptions = useStore($socketOptions);
@@ -61,8 +66,9 @@ export const useSocketIO = () => {
      return;
    }

-    const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
+    const socket: AppSocket = io(socketUrl, socketOptions);
-    setEventListeners({ dispatch, socket });
+    $socket.set(socket);
+    setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
    socket.connect();

    if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@@ -84,5 +90,5 @@ export const useSocketIO = () => {
      socket.disconnect();
      $isSocketInitialized.set(false);
    };
-  }, [dispatch, socketOptions, socketUrl]);
+  }, [dispatch, getState, socketOptions, socketUrl]);
};
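The hunks above move the live socket instance into a nanostores `$socket` atom so code outside the React tree can reach it. A minimal sketch of that atom pattern, assuming stock nanostores and socket.io-client; the `$demoSocket` name and wiring below are illustrative, not the app's real protocol:

```ts
import { atom } from 'nanostores';
import { io, type Socket } from 'socket.io-client';

// Holds the current socket (or null before connection), readable from anywhere.
const $demoSocket = atom<Socket | null>(null);

function connect(url: string): void {
  const socket = io(url, { autoConnect: false });
  $demoSocket.set(socket); // publish the instance before connecting
  socket.connect();
}

// Any module can subscribe without React context or Redux:
$demoSocket.subscribe((socket) => {
  console.log(socket ? 'socket instance available' : 'no socket yet');
});

connect('http://localhost:9090');
```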
@@ -15,21 +15,21 @@ export const BASE_CONTEXT = {};

export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));

-export type LoggerNamespace =
+export const zLogNamespace = z.enum([
-  | 'images'
+  'canvas',
-  | 'models'
+  'config',
-  | 'config'
+  'events',
-  | 'canvas'
+  'gallery',
-  | 'generation'
+  'generation',
-  | 'nodes'
+  'metadata',
-  | 'system'
+  'models',
-  | 'socketio'
+  'system',
-  | 'session'
+  'queue',
-  | 'queue'
+  'workflows',
-  | 'dnd'
+]);
-  | 'controlLayers';
+export type LogNamespace = z.infer<typeof zLogNamespace>;

-export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
+export const logger = (namespace: LogNamespace) => $logger.get().child({ namespace });

export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fatal']);
export type LogLevel = z.infer<typeof zLogLevel>;
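The hunk above replaces a hand-maintained string-union type with a zod enum, so the namespace list exists at runtime (for validation and UI) as well as in the type system. A short sketch of the pattern, assuming stock zod; the list mirrors the one in the diff:

```ts
import { z } from 'zod';

// Single source of truth: the runtime enum.
const zLogNamespace = z.enum([
  'canvas',
  'config',
  'events',
  'gallery',
  'generation',
  'metadata',
  'models',
  'queue',
  'system',
  'workflows',
]);

// The TypeScript union is inferred from the enum instead of being written twice.
type LogNamespace = z.infer<typeof zLogNamespace>;

// Runtime validation comes for free, e.g. when reading persisted settings:
const ok = zLogNamespace.safeParse('canvas'); // { success: true, data: 'canvas' }
const stale = zLogNamespace.safeParse('dnd'); // { success: false, ... } — no longer a namespace

const active: LogNamespace[] = ok.success ? [ok.data] : [];
console.log(active, stale.success);
```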
@@ -1,29 +1,41 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
+import {
+  selectSystemLogIsEnabled,
+  selectSystemLogLevel,
+  selectSystemLogNamespaces,
+} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';

-import type { LoggerNamespace } from './logger';
+import type { LogNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';

-export const useLogger = (namespace: LoggerNamespace) => {
+export const useLogger = (namespace: LogNamespace) => {
-  const consoleLogLevel = useAppSelector((s) => s.system.consoleLogLevel);
+  const logLevel = useAppSelector(selectSystemLogLevel);
-  const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);
+  const logNamespaces = useAppSelector(selectSystemLogNamespaces);
+  const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);

  // The provided Roarr browser log writer uses localStorage to config logging to console
  useEffect(() => {
-    if (shouldLogToConsole) {
+    if (logIsEnabled) {
      // Enable console log output
      localStorage.setItem('ROARR_LOG', 'true');

      // Use a filter to show only logs of the given level
-      localStorage.setItem('ROARR_FILTER', `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`);
+      let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
+      if (logNamespaces.length > 0) {
+        filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
+      } else {
+        filter += ' AND context.namespace:undefined';
+      }
+      localStorage.setItem('ROARR_FILTER', filter);
    } else {
      // Disable console log output
      localStorage.setItem('ROARR_LOG', 'false');
    }
    ROARR.write = createLogWriter();
-  }, [consoleLogLevel, shouldLogToConsole]);
+  }, [logLevel, logIsEnabled, logNamespaces]);

  // Update the module-scoped logger context as needed
  useEffect(() => {
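To make the `ROARR_FILTER` value built in the hunk above concrete, here is the same logic lifted into a standalone function together with an example of the string it produces. The numeric `LOG_LEVEL_MAP` values below are an assumption based on Roarr's conventional levels (trace=10 through fatal=60), not copied from the repo:

```ts
// Assumed level mapping; adjust if the app's LOG_LEVEL_MAP differs.
const LOG_LEVEL_MAP = { trace: 10, debug: 20, info: 30, warn: 40, error: 50, fatal: 60 } as const;

type Level = keyof typeof LOG_LEVEL_MAP;

function buildRoarrFilter(logLevel: Level, namespaces: string[]): string {
  let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
  if (namespaces.length > 0) {
    // Restrict output to the selected namespaces.
    filter += ` AND (${namespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
  } else {
    // With nothing selected, match only log entries that carry no namespace.
    filter += ' AND context.namespace:undefined';
  }
  return filter;
}

console.log(buildRoarrFilter('warn', ['canvas', 'queue']));
// "context.logLevel:>=40 AND (context.namespace:canvas OR context.namespace:queue)"
```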
@@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
-import type { InvokeTabName } from 'features/ui/store/tabMap';
+import type { TabName } from 'features/ui/store/uiTypes';

export const enqueueRequested = createAction<{
-  tabName: InvokeTabName;
+  tabName: TabName;
  prepend: boolean;
}>('app/enqueueRequested');
@@ -1,2 +1,3 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
+export const EMPTY_OBJECT = {};
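The new `EMPTY_OBJECT` joins `EMPTY_ARRAY` as a stable fallback reference. A short sketch (not code from this PR; the state shape is illustrative) of why such module-level constants matter when selectors are memoized by reference equality:

```ts
import { createSelector } from '@reduxjs/toolkit';

export const EMPTY_ARRAY: readonly never[] = [];

type State = { gallery: { selection?: string[] } };

// Problematic: a fresh array identity on every call when `selection` is undefined,
// so downstream memoization and React re-render checks never see "the same" value.
const selectSelectionUnstable = (s: State) => s.gallery.selection ?? [];

// Better: fall back to the shared constant so the result is referentially stable.
const selectSelectionStable = (s: State) => s.gallery.selection ?? EMPTY_ARRAY;

// The derived selector only recomputes when the selection reference actually changes.
const selectSelectionCount = createSelector(selectSelectionStable, (selection) => selection.length);

console.log(selectSelectionCount({ gallery: {} }), selectSelectionUnstable({ gallery: {} }));
```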
@@ -1,5 +1,6 @@
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
+import type { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es';

/**
@@ -19,3 +20,5 @@ export const getSelectorsOptions: GetSelectorsOptions = {
    argsMemoize: lruMemoize,
  }),
};

+export const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();
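`createMemoizedSelector.withTypes<RootState>()` binds the store's state type once so individual selectors infer it instead of annotating it at every call site. A sketch under the assumption of the Reselect 5-style APIs re-exported by Redux Toolkit; the `RootState` below is a stand-in, not the app's real store type:

```ts
import { createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';

// Stand-in state type for illustration.
type RootState = { system: { language: string } };

// A selector creator that memoizes with an LRU cache and deep-equal result checks.
const createMemoizedSelector = createSelectorCreator({
  memoize: lruMemoize,
  memoizeOptions: { resultEqualityCheck: isEqual },
});

// Bind the state type once...
const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();

// ...so `s` is inferred as RootState here without an explicit annotation.
const selectLanguage = createMemoizedAppSelector([(s) => s.system.language], (language) => language);

console.log(selectLanguage({ system: { language: 'en' } })); // "en"
```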
@@ -1,5 +1,4 @@
import { logger } from 'app/logging/logger';
-import { parseify } from 'common/util/serialize';
import { PersistError, RehydrateError } from 'redux-remember';
import { serializeError } from 'serialize-error';

@@ -41,6 +40,6 @@ export const errorHandler = (err: PersistError | RehydrateError) => {
  } else if (err instanceof RehydrateError) {
    log.error({ error: serializeError(err) }, 'Problem rehydrating state');
  } else {
-    log.error({ error: parseify(err) }, 'Problem in persistence layer');
+    log.error({ error: serializeError(err) }, 'Problem in persistence layer');
  }
};
@@ -1,9 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
-import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
-import { socketGeneratorProgress } from 'services/events/actions';

export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
  if (isAnyGraphBuilt(action)) {
@@ -24,13 +22,5 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
    };
  }

-  if (socketGeneratorProgress.match(action)) {
-    const sanitized = deepClone(action);
-    if (sanitized.payload.data.progress_image) {
-      sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
-    }
-    return sanitized;
-  }

  return action;
};
@@ -1,7 +1,7 @@
 import type { TypedStartListening } from '@reduxjs/toolkit';
-import { createListenerMiddleware } from '@reduxjs/toolkit';
+import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
 import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
-import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
+import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
 import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
 import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
 import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -9,17 +9,7 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
 import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
 import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
 import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
-import { addCanvasCopiedToClipboardListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard';
-import { addCanvasDownloadedAsImageListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage';
-import { addCanvasImageToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet';
-import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery';
-import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
-import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
-import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
-import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
-import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
-import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
-import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
+import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
 import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
 import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
 import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@@ -37,16 +27,7 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
 import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
 import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
 import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
-import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
-import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
-import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
-import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
-import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
-import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
-import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
-import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
-import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
-import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
+import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
 import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
 import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
 import type { AppDispatch, RootState } from 'app/store/store';
@@ -60,6 +41,8 @@ export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
 
 const startAppListening = listenerMiddleware.startListening as AppStartListening;
 
+export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
+
 /**
  * The RTK listener middleware is a lightweight alternative to sagas/observables.
  *
@@ -83,7 +66,6 @@ addGalleryImageClickedListener(startAppListening);
 addGalleryOffsetChangedListener(startAppListening);
 
 // User Invoked
-addEnqueueRequestedCanvasListener(startAppListening);
 addEnqueueRequestedNodes(startAppListening);
 addEnqueueRequestedLinear(startAppListening);
 addEnqueueRequestedUpscale(startAppListening);
@@ -91,31 +73,22 @@ addAnyEnqueuedListener(startAppListening);
 addBatchEnqueuedListener(startAppListening);
 
 // Canvas actions
-addCanvasSavedToGalleryListener(startAppListening);
-addCanvasMaskSavedToGalleryListener(startAppListening);
-addCanvasImageToControlNetListener(startAppListening);
-addCanvasMaskToControlNetListener(startAppListening);
-addCanvasDownloadedAsImageListener(startAppListening);
-addCanvasCopiedToClipboardListener(startAppListening);
-addCanvasMergedListener(startAppListening);
-addStagingAreaImageSavedListener(startAppListening);
-addCommitStagingAreaImageListener(startAppListening);
+// addCanvasSavedToGalleryListener(startAppListening);
+// addCanvasMaskSavedToGalleryListener(startAppListening);
+// addCanvasImageToControlNetListener(startAppListening);
+// addCanvasMaskToControlNetListener(startAppListening);
+// addCanvasDownloadedAsImageListener(startAppListening);
+// addCanvasCopiedToClipboardListener(startAppListening);
+// addCanvasMergedListener(startAppListening);
+// addStagingAreaImageSavedListener(startAppListening);
+// addCommitStagingAreaImageListener(startAppListening);
+addStagingListeners(startAppListening);
 
 // Socket.IO
-addGeneratorProgressEventListener(startAppListening);
-addInvocationCompleteEventListener(startAppListening);
-addInvocationErrorEventListener(startAppListening);
-addInvocationStartedEventListener(startAppListening);
 addSocketConnectedEventListener(startAppListening);
-addSocketDisconnectedEventListener(startAppListening);
-addModelLoadEventListener(startAppListening);
-addModelInstallEventListener(startAppListening);
-addSocketQueueItemStatusChangedEventListener(startAppListening);
-addBulkDownloadListeners(startAppListening);
 
-// ControlNet
-addControlNetImageProcessedListener(startAppListening);
-addControlNetAutoProcessListener(startAppListening);
+// Gallery bulk download
+addBulkDownloadListeners(startAppListening);
 
 // Boards
 addImageAddedToBoardFulfilledListener(startAppListening);
@@ -148,4 +121,6 @@ addAdHocPostProcessingRequestedListener(startAppListening);
 addDynamicPromptsListener(startAppListening);
 
 addSetDefaultSettingsListener(startAppListening);
-addControlAdapterPreprocessor(startAppListening);
+// addControlAdapterPreprocessor(startAppListening);
+
+addCancellationsListeners(startAppListening);
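
The new addAppListener export wraps RTK's addListener action creator with the app's RootState and AppDispatch types, so listeners can also be registered at runtime by dispatching an action rather than only at store setup. A rough sketch of that pattern, assuming a dispatch obtained from the configured store (the helper name below is hypothetical):

    import { addAppListener } from 'app/store/middleware/listenerMiddleware';
    import type { AppDispatch } from 'app/store/store';
    import { appInfoApi } from 'services/api/endpoints/appInfo';

    // Hypothetical helper: registers a one-off listener at runtime. The dispatched
    // addAppListener action is intercepted by the listener middleware, which hands
    // back an unsubscribe function.
    export const listenForAppConfigOnce = (dispatch: AppDispatch) => {
      return dispatch(
        addAppListener({
          matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
          effect: (action, listenerApi) => {
            // React to the fulfilled query once, then stop listening.
            listenerApi.unsubscribe();
          },
        })
      );
    };
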
@@ -1,21 +1,21 @@
 import { createAction } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
+import type { SerializableObject } from 'common/types';
 import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { queueApi } from 'services/api/endpoints/queue';
 import type { BatchConfig, ImageDTO } from 'services/api/types';
 
+const log = logger('queue');
+
 export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
 
 export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: adHocPostProcessingRequested,
     effect: async (action, { dispatch, getState }) => {
-      const log = logger('session');
-
       const { imageDTO } = action.payload;
       const state = getState();
 
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
 
       const enqueueResult = await req.unwrap();
       req.reset();
-      log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
+      log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
     } catch (error) {
-      log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
+      log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
 
       if (error instanceof Object && 'status' in error && error.status === 403) {
         return;
@@ -23,7 +23,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
    */
   startAppListening({
     matcher: matchAnyBoardDeleted,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
      const state = getState();
      const deletedBoardId = action.meta.arg.originalArgs;
      const { autoAddBoardId, selectedBoardId } = state.gallery;
@@ -44,7 +44,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
   // If we archived a board, it may end up hidden. If it's selected or the auto-add board, we should reset those.
   startAppListening({
     matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
      const state = getState();
      const { shouldShowArchivedBoards } = state.gallery;
 
@@ -61,7 +61,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
   // When we hide archived boards, if the selected or the auto-add board is archived, we should reset those.
   startAppListening({
     actionCreator: shouldShowArchivedBoardsChanged,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
      const shouldShowArchivedBoards = action.payload;
 
      // We only need to take action if we have just hidden archived boards.
@@ -100,7 +100,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
    */
   startAppListening({
     matcher: boardsApi.endpoints.listAllBoards.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
      const boards = action.payload;
      const state = getState();
      const { selectedBoardId, autoAddBoardId } = state.gallery;
@@ -1,33 +1,34 @@
 import { isAnyOf } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import {
-  canvasBatchIdsReset,
-  commitStagingAreaImage,
-  discardStagedImages,
-  resetCanvas,
-  setInitialCanvasImage,
-} from 'features/canvas/store/canvasSlice';
+import { canvasReset, rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
+import { stagingAreaImageAccepted, stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
+import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
+import { imageDTOToImageObject } from 'features/controlLayers/store/types';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { queueApi } from 'services/api/endpoints/queue';
+import { assert } from 'tsafe';
 
-const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
+const log = logger('canvas');
 
-export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
+const matchCanvasOrStagingAreaRest = isAnyOf(stagingAreaReset, canvasReset);
+
+export const addStagingListeners = (startAppListening: AppStartListening) => {
   startAppListening({
-    matcher,
-    effect: async (_, { dispatch, getState }) => {
-      const log = logger('canvas');
-      const state = getState();
-      const { batchIds } = state.canvas;
+    matcher: matchCanvasOrStagingAreaRest,
+    effect: async (_, { dispatch }) => {
 
       try {
         const req = dispatch(
-          queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
+          queueApi.endpoints.cancelByBatchDestination.initiate(
+            { destination: 'canvas' },
+            { fixedCacheKey: 'cancelByBatchOrigin' }
+          )
         );
         const { canceled } = await req.unwrap();
         req.reset();
 
         if (canceled > 0) {
           log.debug(`Canceled ${canceled} canvas batches`);
           toast({
@@ -36,7 +37,6 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
             status: 'success',
           });
         }
-        dispatch(canvasBatchIdsReset());
       } catch {
         log.error('Failed to cancel canvas batches');
         toast({
@@ -47,4 +47,26 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
       }
     },
   });
+
+  startAppListening({
+    actionCreator: stagingAreaImageAccepted,
+    effect: (action, api) => {
+      const { index } = action.payload;
+      const state = api.getState();
+      const stagingAreaImage = state.canvasStagingArea.stagedImages[index];
+
+      assert(stagingAreaImage, 'No staged image found to accept');
+      const { x, y } = selectCanvasSlice(state).bbox.rect;
+
+      const { imageDTO, offsetX, offsetY } = stagingAreaImage;
+      const imageObject = imageDTOToImageObject(imageDTO);
+      const overrides: Partial<CanvasRasterLayerState> = {
+        position: { x: x + offsetX, y: y + offsetY },
+        objects: [imageObject],
+      };
+
+      api.dispatch(rasterLayerAdded({ overrides, isSelected: false }));
+      api.dispatch(stagingAreaReset());
+    },
+  });
 };
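
In the new staging listener above, the accepted image's raster layer position is simply the canvas bbox origin plus the offsets that were stored with the staged image. A toy check of that arithmetic, with made-up numbers:

    // Hypothetical values, for illustration only.
    const bbox = { x: 512, y: 256 };              // selectCanvasSlice(state).bbox.rect origin
    const staged = { offsetX: 64, offsetY: -32 }; // offsets recorded when the image was staged

    // Mirrors `position: { x: x + offsetX, y: y + offsetY }` from the listener.
    const position = { x: bbox.x + staged.offsetX, y: bbox.y + staged.offsetY };
    // position -> { x: 576, y: 224 }
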
@@ -4,7 +4,7 @@ import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
 export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
-    effect: async (_, { dispatch, getState }) => {
+    effect: (_, { dispatch, getState }) => {
       const { data } = selectQueueStatus(getState());
 
       if (!data || data.processor.is_started) {
@@ -1,14 +1,14 @@
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { setInfillMethod } from 'features/parameters/store/generationSlice';
+import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
 import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
 import { appInfoApi } from 'services/api/endpoints/appInfo';
 
 export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
-    effect: async (action, { getState, dispatch }) => {
+    effect: (action, { getState, dispatch }) => {
       const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
-      const infillMethod = getState().generation.infillMethod;
+      const infillMethod = getState().params.infillMethod;
 
       if (!infill_methods.includes(infillMethod)) {
         // if there is no infill method, set it to the first one
@@ -6,7 +6,7 @@ export const appStarted = createAction('app/appStarted');
 export const addAppStartedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: appStarted,
-    effect: async (action, { unsubscribe, cancelActiveListeners }) => {
+    effect: (action, { unsubscribe, cancelActiveListeners }) => {
       // this should only run once
       cancelActiveListeners();
       unsubscribe();
@@ -1,27 +1,30 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
+import type { SerializableObject } from 'common/types';
 import { zPydanticValidationError } from 'features/system/store/zodSchemas';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { truncate, upperFirst } from 'lodash-es';
+import { serializeError } from 'serialize-error';
 import { queueApi } from 'services/api/endpoints/queue';
 
+const log = logger('queue');
+
 export const addBatchEnqueuedListener = (startAppListening: AppStartListening) => {
   // success
   startAppListening({
     matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
-    effect: async (action) => {
-      const response = action.payload;
+    effect: (action) => {
+      const enqueueResult = action.payload;
       const arg = action.meta.arg.originalArgs;
-      logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
+      log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
 
       toast({
         id: 'QUEUE_BATCH_SUCCEEDED',
         title: t('queue.batchQueued'),
         status: 'success',
         description: t('queue.batchQueuedDesc', {
-          count: response.enqueued,
+          count: enqueueResult.enqueued,
           direction: arg.prepend ? t('queue.front') : t('queue.back'),
         }),
       });
@@ -31,9 +34,9 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
   // error
   startAppListening({
     matcher: queueApi.endpoints.enqueueBatch.matchRejected,
-    effect: async (action) => {
+    effect: (action) => {
       const response = action.payload;
-      const arg = action.meta.arg.originalArgs;
+      const batchConfig = action.meta.arg.originalArgs;
 
       if (!response) {
         toast({
@@ -42,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
           status: 'error',
           description: t('common.unknownError'),
         });
-        logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
+        log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
         return;
       }
 
@@ -68,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
         description: t('common.unknownError'),
       });
       }
-      logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
+      log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
     },
   });
 };
@@ -1,47 +1,31 @@
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { resetCanvas } from 'features/canvas/store/canvasSlice';
-import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
-import { allLayersDeleted } from 'features/controlLayers/store/controlLayersSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
 import { getImageUsage } from 'features/deleteImageModal/store/selectors';
 import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
+import { selectNodesSlice } from 'features/nodes/store/selectors';
 import { imagesApi } from 'services/api/endpoints/images';
 
 export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const { deleted_images } = action.payload;
 
       // Remove all deleted images from the UI
 
-      let wasCanvasReset = false;
       let wasNodeEditorReset = false;
-      let wereControlAdaptersReset = false;
-      let wereControlLayersReset = false;
 
-      const { canvas, nodes, controlAdapters, controlLayers } = getState();
+      const state = getState();
+      const nodes = selectNodesSlice(state);
+      const canvas = selectCanvasSlice(state);
 
       deleted_images.forEach((image_name) => {
-        const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
+        const imageUsage = getImageUsage(nodes, canvas, image_name);
 
-        if (imageUsage.isCanvasImage && !wasCanvasReset) {
-          dispatch(resetCanvas());
-          wasCanvasReset = true;
-        }
-
         if (imageUsage.isNodesImage && !wasNodeEditorReset) {
           dispatch(nodeEditorReset());
           wasNodeEditorReset = true;
         }
-
-        if (imageUsage.isControlImage && !wereControlAdaptersReset) {
-          dispatch(controlAdaptersReset());
-          wereControlAdaptersReset = true;
-        }
-
-        if (imageUsage.isControlLayerImage && !wereControlLayersReset) {
-          dispatch(allLayersDeleted());
-          wereControlLayersReset = true;
-        }
       });
     },
   });
@@ -1,21 +1,15 @@
-import { ExternalLink } from '@invoke-ai/ui-library';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { imagesApi } from 'services/api/endpoints/images';
-import {
-  socketBulkDownloadComplete,
-  socketBulkDownloadError,
-  socketBulkDownloadStarted,
-} from 'services/events/actions';
 
-const log = logger('images');
+const log = logger('gallery');
 
 export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
-    effect: async (action) => {
+    effect: (action) => {
       log.debug(action.payload, 'Bulk download requested');
 
       // If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@@ -33,7 +27,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
 
   startAppListening({
     matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
-    effect: async () => {
+    effect: () => {
       log.debug('Bulk download request failed');
 
       // There isn't any toast to update if we get this event.
@@ -44,55 +38,4 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
       });
     },
   });
-
-  startAppListening({
-    actionCreator: socketBulkDownloadStarted,
-    effect: async (action) => {
-      // This should always happen immediately after the bulk download request, so we don't need to show a toast here.
-      log.debug(action.payload.data, 'Bulk download preparation started');
-    },
-  });
-
-  startAppListening({
-    actionCreator: socketBulkDownloadComplete,
-    effect: async (action) => {
-      log.debug(action.payload.data, 'Bulk download preparation completed');
-
-      const { bulk_download_item_name } = action.payload.data;
-
-      // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
-      const url = `/api/v1/images/download/${bulk_download_item_name}`;
-
-      toast({
-        id: bulk_download_item_name,
-        title: t('gallery.bulkDownloadReady', 'Download ready'),
-        status: 'success',
-        description: (
-          <ExternalLink
-            label={t('gallery.clickToDownload', 'Click here to download')}
-            href={url}
-            download={bulk_download_item_name}
-          />
-        ),
-        duration: null,
-      });
-    },
-  });
-
-  startAppListening({
-    actionCreator: socketBulkDownloadError,
-    effect: async (action) => {
-      log.debug(action.payload.data, 'Bulk download preparation failed');
-
-      const { bulk_download_item_name } = action.payload.data;
-
-      toast({
-        id: bulk_download_item_name,
-        title: t('gallery.bulkDownloadFailed'),
-        status: 'error',
-        description: action.payload.data.error,
-        duration: null,
-      });
-    },
-  });
 };
@@ -0,0 +1,137 @@
+import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
+import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
+import { queueApi } from 'services/api/endpoints/queue';
+
+/**
+ * To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
+ * cancellations:
+ * - In the route handlers above, we track and update the cancellations object
+ * - When the user queues a batch, we should reset the cancellations, also handled in the route handlers above
+ * - When we get a progress event, we should check if the event is cancelled before setting the event
+ *
+ * We have a few ways that cancellations are effected, so we need to track them all:
+ * - by queue item id (in this case, we will compare the session_id and not the item_id)
+ * - by batch id
+ * - by destination
+ * - by clearing the queue
+ */
+type Cancellations = {
+  sessionIds: Set<string>;
+  batchIds: Set<string>;
+  destinations: Set<string>;
+  clearQueue: boolean;
+};
+
+const resetCancellations = (): void => {
+  cancellations.clearQueue = false;
+  cancellations.sessionIds.clear();
+  cancellations.batchIds.clear();
+  cancellations.destinations.clear();
+};
+
+const cancellations: Cancellations = {
+  sessionIds: new Set(),
+  batchIds: new Set(),
+  destinations: new Set(),
+  clearQueue: false,
+} as Readonly<Cancellations>;
+
+/**
+ * Checks if an item is cancelled, used to prevent race conditions with event handling.
+ *
+ * To use this, provide the session_id, batch_id and destination from the event payload.
+ */
+export const getIsCancelled = (item: {
+  session_id: string;
+  batch_id: string;
+  destination?: string | null;
+}): boolean => {
+  if (cancellations.clearQueue) {
+    return true;
+  }
+  if (cancellations.sessionIds.has(item.session_id)) {
+    return true;
+  }
+  if (cancellations.batchIds.has(item.batch_id)) {
+    return true;
+  }
+  if (item.destination && cancellations.destinations.has(item.destination)) {
+    return true;
+  }
+  return false;
+};
+
+export const addCancellationsListeners = (startAppListening: AppStartListening) => {
+  // When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
+  // Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
+  startAppListening({
+    matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
+    effect: () => {
+      resetCancellations();
+    },
+  });
+
+  startAppListening({
+    matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
+    effect: (action) => {
+      cancellations.destinations.add(action.meta.arg.originalArgs.destination);
+
+      const event = $lastCanvasProgressEvent.get();
+      if (!event) {
+        return;
+      }
+      const { session_id, batch_id, destination } = event;
+      if (getIsCancelled({ session_id, batch_id, destination })) {
+        $lastCanvasProgressEvent.set(null);
+      }
+    },
+  });
+
+  startAppListening({
+    matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
+    effect: (action) => {
+      cancellations.sessionIds.add(action.payload.session_id);
+
+      const event = $lastCanvasProgressEvent.get();
+      if (!event) {
+        return;
+      }
+      const { session_id, batch_id, destination } = event;
+      if (getIsCancelled({ session_id, batch_id, destination })) {
+        $lastCanvasProgressEvent.set(null);
+      }
+    },
+  });
+
+  startAppListening({
+    matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
+    effect: (action) => {
+      for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
+        cancellations.batchIds.add(batch_id);
+      }
+      const event = $lastCanvasProgressEvent.get();
+      if (!event) {
+        return;
+      }
+      const { session_id, batch_id, destination } = event;
+      if (getIsCancelled({ session_id, batch_id, destination })) {
+        $lastCanvasProgressEvent.set(null);
+      }
+    },
+  });
+
+  startAppListening({
+    matcher: queueApi.endpoints.clearQueue.matchFulfilled,
+    effect: () => {
+      cancellations.clearQueue = true;
+      const event = $lastCanvasProgressEvent.get();
+      if (!event) {
+        return;
+      }
+      const { session_id, batch_id, destination } = event;
+      if (getIsCancelled({ session_id, batch_id, destination })) {
+        $lastCanvasProgressEvent.set(null);
+      }
+    },
+  });
+};
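
The comment block in the new cancellations file describes the other half of the race-condition guard: progress-event handlers are expected to call getIsCancelled before storing an event, so a late event for cancelled work cannot resurrect stale progress state (for example in $lastCanvasProgressEvent). A minimal sketch of that check, with a hypothetical handler and event type; only getIsCancelled and its import path come from the diff above:

    import { getIsCancelled } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';

    // Hypothetical payload shape for a generator-progress event.
    type ProgressEvent = {
      session_id: string;
      batch_id: string;
      destination?: string | null;
    };

    // Guard used before persisting a progress event: a late event for work that was
    // already cancelled should be ignored rather than stored.
    export const shouldStoreProgressEvent = (event: ProgressEvent): boolean => {
      return !getIsCancelled(event);
    };
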
@@ -1,38 +0,0 @@
|
|||||||
import { $logger } from 'app/logging/logger';
|
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
|
||||||
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
|
|
||||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
|
||||||
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
|
|
||||||
import { toast } from 'features/toast/toast';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
|
|
||||||
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
|
|
||||||
startAppListening({
|
|
||||||
actionCreator: canvasCopiedToClipboard,
|
|
||||||
effect: async (action, { getState }) => {
|
|
||||||
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
|
|
||||||
const state = getState();
|
|
||||||
|
|
||||||
try {
|
|
||||||
const blob = getBaseLayerBlob(state);
|
|
||||||
|
|
||||||
copyBlobToClipboard(blob);
|
|
||||||
} catch (err) {
|
|
||||||
moduleLog.error(String(err));
|
|
||||||
toast({
|
|
||||||
id: 'CANVAS_COPY_FAILED',
|
|
||||||
title: t('toast.problemCopyingCanvas'),
|
|
||||||
description: t('toast.problemCopyingCanvasDesc'),
|
|
||||||
status: 'error',
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
toast({
|
|
||||||
id: 'CANVAS_COPY_SUCCEEDED',
|
|
||||||
title: t('toast.canvasCopiedClipboard'),
|
|
||||||
status: 'success',
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
||||||
@@ -1,34 +0,0 @@
-import { $logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
-import { downloadBlob } from 'features/canvas/util/downloadBlob';
-import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-
-export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasDownloadedAsImage,
-    effect: async (action, { getState }) => {
-      const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
-      const state = getState();
-
-      let blob;
-      try {
-        blob = await getBaseLayerBlob(state);
-      } catch (err) {
-        moduleLog.error(String(err));
-        toast({
-          id: 'CANVAS_DOWNLOAD_FAILED',
-          title: t('toast.problemDownloadingCanvas'),
-          description: t('toast.problemDownloadingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      downloadBlob(blob, 'canvas.png');
-      toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
-    },
-  });
-};
Some files were not shown because too many files have changed in this diff.