mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
627 Commits
v6.0.0a4
...
psychedeli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1a4376b75 | ||
|
|
ef4d5d7377 | ||
|
|
6b0dfd8427 | ||
|
|
471c010217 | ||
|
|
b1193022f7 | ||
|
|
2152ca092c | ||
|
|
ccc62ba56d | ||
|
|
9cf82de8c5 | ||
|
|
aced349152 | ||
|
|
0d67ee6548 | ||
|
|
03c21d1607 | ||
|
|
752e8db1f5 | ||
|
|
85fc861dd9 | ||
|
|
458cbfd874 | ||
|
|
04331c070a | ||
|
|
632ddf0cb4 | ||
|
|
2b193ff416 | ||
|
|
96ee394f9e | ||
|
|
0badc80c0c | ||
|
|
78e6cbf96e | ||
|
|
0b969a661b | ||
|
|
6fe47ec9f8 | ||
|
|
3850dd61f8 | ||
|
|
75520eaf0f | ||
|
|
10e88c58c1 | ||
|
|
30ed4dbd92 | ||
|
|
ed9c090f33 | ||
|
|
d29f65ed22 | ||
|
|
2062ec8ac0 | ||
|
|
49e818338a | ||
|
|
1caab2b9c4 | ||
|
|
50079ea349 | ||
|
|
fffa1b24c4 | ||
|
|
a6d6170387 | ||
|
|
e5fceb0448 | ||
|
|
059baf5b29 | ||
|
|
1be8a9a310 | ||
|
|
7adc33e04d | ||
|
|
7f2dd22d47 | ||
|
|
bb50f4b8a2 | ||
|
|
a48958e0d4 | ||
|
|
e3a1e9af53 | ||
|
|
c6fe11c42f | ||
|
|
4eb1bd67df | ||
|
|
c376f914d2 | ||
|
|
b5d1c47ef7 | ||
|
|
004a52ca65 | ||
|
|
b1d5a51ddf | ||
|
|
2b2498eaa1 | ||
|
|
10dda4440e | ||
|
|
98f78abefa | ||
|
|
cc93fa270f | ||
|
|
014b27680f | ||
|
|
c3d8f875de | ||
|
|
79f9dc6e4a | ||
|
|
6e1c0c1105 | ||
|
|
0362524040 | ||
|
|
dc6656459b | ||
|
|
3ea1b97f6f | ||
|
|
a7c7405ccc | ||
|
|
c391f1117a | ||
|
|
b1e2cb8401 | ||
|
|
db6af134b7 | ||
|
|
7e6cffb00c | ||
|
|
5b187bcb00 | ||
|
|
0843d609a3 | ||
|
|
95bd9cef18 | ||
|
|
931d6521f6 | ||
|
|
e37665ff59 | ||
|
|
56857fbbe6 | ||
|
|
43cfb8a574 | ||
|
|
05b1682d15 | ||
|
|
69a08ee7f2 | ||
|
|
18212c7d8a | ||
|
|
7de26f8e69 | ||
|
|
0652b12a6f | ||
|
|
43a361a00f | ||
|
|
cf68ad9cbc | ||
|
|
ec02a39325 | ||
|
|
e52d7a05c2 | ||
|
|
c9d4e2b761 | ||
|
|
ac26aa9508 | ||
|
|
9ff6ada15b | ||
|
|
e81a115169 | ||
|
|
52827807de | ||
|
|
b631de4cb5 | ||
|
|
099ebdbc37 | ||
|
|
4de6549be9 | ||
|
|
368be34949 | ||
|
|
5baa4bd916 | ||
|
|
4229377532 | ||
|
|
2610772ffd | ||
|
|
193de6a8f2 | ||
|
|
7ea343c787 | ||
|
|
12179dabba | ||
|
|
ef135f9923 | ||
|
|
e6c67cc00f | ||
|
|
179b988148 | ||
|
|
d913a3c85b | ||
|
|
e79525c40c | ||
|
|
f409f913ac | ||
|
|
7a79f61d4c | ||
|
|
ea182c234b | ||
|
|
f2eee4a82d | ||
|
|
e129525306 | ||
|
|
ecedfce758 | ||
|
|
702cb2cb1e | ||
|
|
2e8db3cce3 | ||
|
|
7845623fa5 | ||
|
|
e6a25ca7a2 | ||
|
|
71e12bcebe | ||
|
|
863c7eb9e2 | ||
|
|
9945c20d02 | ||
|
|
e3c1334b1f | ||
|
|
c143f63ef0 | ||
|
|
067026a0d0 | ||
|
|
66991334fc | ||
|
|
b771c3b164 | ||
|
|
4925694dc1 | ||
|
|
0a737ced44 | ||
|
|
8d83caaae0 | ||
|
|
16c8017f1a | ||
|
|
61a35f1396 | ||
|
|
6bd004d868 | ||
|
|
b6a6d406c7 | ||
|
|
8e287c32ee | ||
|
|
2d8b5e26c2 | ||
|
|
50914b74ee | ||
|
|
0fc1c33536 | ||
|
|
3b08c35f72 | ||
|
|
607b2561fd | ||
|
|
d68f922efb | ||
|
|
2bbd74d418 | ||
|
|
3a5392a9ee | ||
|
|
6f80efe71d | ||
|
|
7fac833813 | ||
|
|
b67eb4134d | ||
|
|
522eeda2e2 | ||
|
|
76233241f0 | ||
|
|
54be9989c5 | ||
|
|
0d3af08d27 | ||
|
|
767ac91f2c | ||
|
|
68571ece8f | ||
|
|
01100a2b9a | ||
|
|
ce2e6d8ab6 | ||
|
|
4887424ca3 | ||
|
|
28f6a20e71 | ||
|
|
c4142e75b2 | ||
|
|
fefe563127 | ||
|
|
1c72f1ff9f | ||
|
|
605cc7369d | ||
|
|
e7ce08cffa | ||
|
|
983cb5ebd2 | ||
|
|
52dbdb7118 | ||
|
|
71e6f00e10 | ||
|
|
e73150c3e6 | ||
|
|
f2426c3ab2 | ||
|
|
9d9c4c0f1a | ||
|
|
acb930f6b9 | ||
|
|
585b54dc7d | ||
|
|
f65affc0ec | ||
|
|
22d574c92a | ||
|
|
f23be119fc | ||
|
|
2d06949e80 | ||
|
|
67804313e1 | ||
|
|
dc23be117a | ||
|
|
350de058fc | ||
|
|
fd5cd707a3 | ||
|
|
98ecefdce0 | ||
|
|
42688a0993 | ||
|
|
d94aa4abf7 | ||
|
|
69a56aafed | ||
|
|
56873f6936 | ||
|
|
6bc6a680cf | ||
|
|
9a49682f60 | ||
|
|
ff84b0a495 | ||
|
|
bcced8a5e8 | ||
|
|
4a18e9eaea | ||
|
|
dde5bf61be | ||
|
|
987e401709 | ||
|
|
5c5ac570e3 | ||
|
|
309903fe0f | ||
|
|
f16ea43e9a | ||
|
|
d794aedb43 | ||
|
|
9930440f33 | ||
|
|
f0a6c4aa1f | ||
|
|
f36d22f13c | ||
|
|
e0d7fab524 | ||
|
|
f20c230f4a | ||
|
|
05c9bc730e | ||
|
|
f17ac06591 | ||
|
|
b35f93d919 | ||
|
|
289d8076d8 | ||
|
|
604763d20f | ||
|
|
7b452f098d | ||
|
|
b41c18d35f | ||
|
|
8328081333 | ||
|
|
07517cf2c2 | ||
|
|
6b98ad9095 | ||
|
|
0de3967e7e | ||
|
|
1335377fb1 | ||
|
|
adbcc191d9 | ||
|
|
11fc7af1c8 | ||
|
|
6f12fd22b9 | ||
|
|
324b6e2af4 | ||
|
|
038010a1ca | ||
|
|
2dd1bc54c9 | ||
|
|
8b69842678 | ||
|
|
9821f7c4fc | ||
|
|
2290ff4ad6 | ||
|
|
8d82ad6d0b | ||
|
|
8ed9f652e8 | ||
|
|
ee8ed344bd | ||
|
|
6d16cfdbe2 | ||
|
|
3ef2872dda | ||
|
|
b52ba149b4 | ||
|
|
c6126c6875 | ||
|
|
3f78ac9295 | ||
|
|
79fea1ac40 | ||
|
|
6eade5781d | ||
|
|
3d8f865fb0 | ||
|
|
dc9cd22d9d | ||
|
|
fe115ff8f9 | ||
|
|
1d35aad213 | ||
|
|
195d6ce893 | ||
|
|
f13ced7ed4 | ||
|
|
735fc276e5 | ||
|
|
cd3caf8c30 | ||
|
|
e9012280ab | ||
|
|
fa72a97794 | ||
|
|
e817631ba3 | ||
|
|
d0619c033f | ||
|
|
6f4850f34f | ||
|
|
072cd9dee7 | ||
|
|
19b6dc1c1f | ||
|
|
7566d0d6c6 | ||
|
|
f123888b46 | ||
|
|
aeab7d0cab | ||
|
|
3f1b2c39ab | ||
|
|
72e3a4b4be | ||
|
|
58e0f80138 | ||
|
|
8b8e29d22d | ||
|
|
90201be670 | ||
|
|
46a5619100 | ||
|
|
d608a7469e | ||
|
|
a7d413d372 | ||
|
|
f5c9e68dbf | ||
|
|
1ded459f03 | ||
|
|
d9024dc230 | ||
|
|
40528692c3 | ||
|
|
f35b05be43 | ||
|
|
29e87fc615 | ||
|
|
ca26b2718e | ||
|
|
5fa6c0b413 | ||
|
|
c37c8c50cd | ||
|
|
f0a4de245d | ||
|
|
5db62f8643 | ||
|
|
e1c478f94c | ||
|
|
11fe3b6332 | ||
|
|
e4aae1a591 | ||
|
|
4d83d1c56d | ||
|
|
34def323e8 | ||
|
|
854956316b | ||
|
|
91afe7884a | ||
|
|
8417ee8a7b | ||
|
|
a035645ed3 | ||
|
|
e00ccba7d3 | ||
|
|
fb883d63aa | ||
|
|
b113c57fc4 | ||
|
|
7636007349 | ||
|
|
fda86ae981 | ||
|
|
c02be4bdf4 | ||
|
|
ed7772d993 | ||
|
|
baae998b5b | ||
|
|
4077ffe595 | ||
|
|
c1937b1379 | ||
|
|
5c66dfed8e | ||
|
|
126dcc96c0 | ||
|
|
cb9c7b4a28 | ||
|
|
e8c4f49a14 | ||
|
|
30fffae637 | ||
|
|
4558a292b6 | ||
|
|
825d17441c | ||
|
|
9b16504af9 | ||
|
|
46c92fadff | ||
|
|
c0467b82ac | ||
|
|
6dafa67286 | ||
|
|
eb406aa07e | ||
|
|
d9422ffebd | ||
|
|
d5c033be4d | ||
|
|
4662cd6f15 | ||
|
|
a740a22613 | ||
|
|
bf4016b4bc | ||
|
|
6fa7c8c2ee | ||
|
|
ea40f582da | ||
|
|
01caf56251 | ||
|
|
42d577e65a | ||
|
|
38d80c9ce5 | ||
|
|
6acaa8abbf | ||
|
|
4b84e34599 | ||
|
|
bbd21b1eb2 | ||
|
|
4fa83a6228 | ||
|
|
051876dcff | ||
|
|
8dc6d0b5ae | ||
|
|
40e9624954 | ||
|
|
ae27c83dc4 | ||
|
|
161059551b | ||
|
|
c196f8a5d5 | ||
|
|
2c6d22664e | ||
|
|
b9ce5389ef | ||
|
|
d1cbf56695 | ||
|
|
e379ac12c3 | ||
|
|
aa10373292 | ||
|
|
780f3692a0 | ||
|
|
3604dcfdd1 | ||
|
|
2b1cffde5e | ||
|
|
83d642ed15 | ||
|
|
455c73235e | ||
|
|
8efef8da41 | ||
|
|
060a9e57b9 | ||
|
|
099d75ca1e | ||
|
|
bbb5d68146 | ||
|
|
9066dc1839 | ||
|
|
075345bffd | ||
|
|
74d1239c87 | ||
|
|
51e1c56636 | ||
|
|
ca1df60e54 | ||
|
|
7549c1250d | ||
|
|
df8751b5a1 | ||
|
|
651b80b997 | ||
|
|
5d236ae4e7 | ||
|
|
e5dc606f5e | ||
|
|
dc6b8e13bd | ||
|
|
c1b34e1f11 | ||
|
|
89f1684072 | ||
|
|
14fbee17a3 | ||
|
|
5dbc32e06e | ||
|
|
23baf61e51 | ||
|
|
5e55f6074b | ||
|
|
f7c555e501 | ||
|
|
6aa605e811 | ||
|
|
f51014e108 | ||
|
|
9862ba9210 | ||
|
|
920aea08cc | ||
|
|
39e584297e | ||
|
|
62a14bb935 | ||
|
|
d7ae2cdf75 | ||
|
|
6172c859ac | ||
|
|
b26fb1f617 | ||
|
|
05167dfd7a | ||
|
|
c090ea7387 | ||
|
|
7ba6c67049 | ||
|
|
3de186061d | ||
|
|
a716381733 | ||
|
|
fb5df06835 | ||
|
|
33c597c224 | ||
|
|
19d882d038 | ||
|
|
ee4bc49bd4 | ||
|
|
188cf37f48 | ||
|
|
15a0a7134c | ||
|
|
22cea0de8b | ||
|
|
cd21816d12 | ||
|
|
605b912ba4 | ||
|
|
52e31112f9 | ||
|
|
a4c9346cd7 | ||
|
|
a1647e4c6e | ||
|
|
8c9ca088a7 | ||
|
|
7a7a2e147c | ||
|
|
adf4cc750a | ||
|
|
9f1ea9d1c7 | ||
|
|
571d286506 | ||
|
|
1320a2c5f8 | ||
|
|
26a9b3131d | ||
|
|
d48140b35d | ||
|
|
9757bb0325 | ||
|
|
38ccd8e09c | ||
|
|
7759b166a9 | ||
|
|
9fc51c7a6e | ||
|
|
62fa4f42f5 | ||
|
|
418ad0de38 | ||
|
|
f4a411326e | ||
|
|
6358f39ebb | ||
|
|
ea8da0bfbf | ||
|
|
5385282325 | ||
|
|
0bf84ab803 | ||
|
|
82f31f2258 | ||
|
|
966dd8857d | ||
|
|
1c778bd719 | ||
|
|
394a14cf61 | ||
|
|
0e843823d1 | ||
|
|
29462e62d2 | ||
|
|
175c0147f8 | ||
|
|
df6e67c982 | ||
|
|
4612f0ac50 | ||
|
|
386a932f2a | ||
|
|
32438532b0 | ||
|
|
ab5cb2c264 | ||
|
|
504daa0ae5 | ||
|
|
14f7c98e8a | ||
|
|
ab39305223 | ||
|
|
7948bca864 | ||
|
|
1a39d22b6c | ||
|
|
9424271d12 | ||
|
|
b5acc204a8 | ||
|
|
7aefa8f36b | ||
|
|
242da9e888 | ||
|
|
1aedc26041 | ||
|
|
2c7fa90892 | ||
|
|
6c8cf99ad2 | ||
|
|
a92ba2542c | ||
|
|
2367b9f945 | ||
|
|
a928ed0204 | ||
|
|
e164451dfe | ||
|
|
d74d079356 | ||
|
|
0eb4360c01 | ||
|
|
937c03f2ec | ||
|
|
f7b249252d | ||
|
|
b2b42be51c | ||
|
|
98368b0665 | ||
|
|
b5eb3d9798 | ||
|
|
1218f49e20 | ||
|
|
89c609fd61 | ||
|
|
b204fb6a91 | ||
|
|
6e3e316416 | ||
|
|
bf5fc9512d | ||
|
|
7080889ed4 | ||
|
|
adea983bfc | ||
|
|
f68d8ed36a | ||
|
|
d45197e0af | ||
|
|
434d8a2b12 | ||
|
|
f55c593705 | ||
|
|
8327d86774 | ||
|
|
c8254710e6 | ||
|
|
0a8f647260 | ||
|
|
32a5e9652a | ||
|
|
87909a06a8 | ||
|
|
2c8ce6f2f4 | ||
|
|
bee4cf41b4 | ||
|
|
049a8d8144 | ||
|
|
ac81ec41c3 | ||
|
|
a294e8e0fd | ||
|
|
4665f0df40 | ||
|
|
70382294f5 | ||
|
|
4028cadfaf | ||
|
|
d23cdfd0ad | ||
|
|
f0ba693922 | ||
|
|
214005d795 | ||
|
|
34aa131115 | ||
|
|
5d8061bea9 | ||
|
|
36ec1015d6 | ||
|
|
7208373576 | ||
|
|
e10afe3026 | ||
|
|
399d6e7bce | ||
|
|
8d0fe5522b | ||
|
|
81341deb46 | ||
|
|
a30933b09c | ||
|
|
3264188ffd | ||
|
|
3984b341e1 | ||
|
|
041023df53 | ||
|
|
b06f76cdb6 | ||
|
|
852badc90b | ||
|
|
01953cf057 | ||
|
|
241844bdef | ||
|
|
33a28ad4f9 | ||
|
|
7c4550cbd5 | ||
|
|
553d1a6ac6 | ||
|
|
f4794e409b | ||
|
|
df87800d61 | ||
|
|
16993cd216 | ||
|
|
7f222ffb9d | ||
|
|
e0ed56ff8d | ||
|
|
e7e1142c77 | ||
|
|
fcaeba290e | ||
|
|
6eecdca56c | ||
|
|
7f44da4902 | ||
|
|
abaa33e22c | ||
|
|
d5c238e7c2 | ||
|
|
18775e8b67 | ||
|
|
903776bfbc | ||
|
|
a5baf0c102 | ||
|
|
a7e45731ec | ||
|
|
32aa3e6d48 | ||
|
|
2f9ea91896 | ||
|
|
5ac5115269 | ||
|
|
161624c722 | ||
|
|
c31cb0b106 | ||
|
|
893f7a8744 | ||
|
|
2e0824a799 | ||
|
|
ed05bf2df3 | ||
|
|
0f1a69a0c3 | ||
|
|
450a0bf142 | ||
|
|
a28c15d545 | ||
|
|
1b1e1983d9 | ||
|
|
d08e2fbd82 | ||
|
|
45b1ef6231 | ||
|
|
3bb446c08f | ||
|
|
8d1ab0a2e5 | ||
|
|
48e2e7e4a1 | ||
|
|
5a2f5c105d | ||
|
|
aa93e95a94 | ||
|
|
a5e5cbd7c3 | ||
|
|
baa9141be3 | ||
|
|
c7ed351bab | ||
|
|
8c17bde4ea | ||
|
|
ba082ccc2f | ||
|
|
01784fb3bf | ||
|
|
a71a0e143c | ||
|
|
94afc13813 | ||
|
|
d640a9001b | ||
|
|
711fe91b24 | ||
|
|
2f26657c17 | ||
|
|
6754fde935 | ||
|
|
ac206f4767 | ||
|
|
c316f07fb2 | ||
|
|
e81dde0933 | ||
|
|
9f392c8c3c | ||
|
|
2531366386 | ||
|
|
9df69496e4 | ||
|
|
2ddcde13ff | ||
|
|
cc5083599d | ||
|
|
2431060a7e | ||
|
|
592c842632 | ||
|
|
bc3550f238 | ||
|
|
23511d68db | ||
|
|
cd0668dd0b | ||
|
|
bf5ed61b84 | ||
|
|
3038a797a6 | ||
|
|
9bbc31b2d9 | ||
|
|
526e6335a1 | ||
|
|
1412c079ad | ||
|
|
6570c0c3b9 | ||
|
|
3a08ea799a | ||
|
|
e3fc244126 | ||
|
|
56938ca0a1 | ||
|
|
5d80642ea4 | ||
|
|
da4b084a8b | ||
|
|
86e1a37a00 | ||
|
|
ea34690709 | ||
|
|
c8df7cd2c0 | ||
|
|
628367b97b | ||
|
|
002816653e | ||
|
|
b05de8634d | ||
|
|
5088e700ad | ||
|
|
d2155e98ef | ||
|
|
7ec511da01 | ||
|
|
985cd8272b | ||
|
|
cd136194ad | ||
|
|
2e2ac71278 | ||
|
|
db4220fb20 | ||
|
|
84f70942e7 | ||
|
|
0af20b03e5 | ||
|
|
e16414b452 | ||
|
|
5dbc2a74a2 | ||
|
|
ad736bc190 | ||
|
|
0e9b71801a | ||
|
|
e80f0b2b43 | ||
|
|
c9042e52d4 | ||
|
|
8a78e37634 | ||
|
|
5e93f58530 | ||
|
|
a3851e0b08 | ||
|
|
eb45a457e9 | ||
|
|
1446d3490b | ||
|
|
579318af70 | ||
|
|
57bfae6774 | ||
|
|
2a92524546 | ||
|
|
7a5fa25b48 | ||
|
|
b3f3020793 | ||
|
|
650809e50d | ||
|
|
7308428f32 | ||
|
|
4dc3f1bcee | ||
|
|
faeb5f0c3b | ||
|
|
d985dfe821 | ||
|
|
ce5ae83689 | ||
|
|
c0428ee7ef | ||
|
|
aa3b2106d4 | ||
|
|
cf2d67ef3d | ||
|
|
c4d1e78f59 | ||
|
|
02e4a3aa82 | ||
|
|
a0b0c30be9 | ||
|
|
5c4cbc7fa2 | ||
|
|
5f2f12f803 | ||
|
|
c9cd0a87be | ||
|
|
668c475271 | ||
|
|
341910739e | ||
|
|
53a3dc52bc | ||
|
|
23b0a4a7f4 | ||
|
|
6afbf31750 | ||
|
|
3cd4306eec | ||
|
|
827191d2fc | ||
|
|
aaa34f717d | ||
|
|
fe83c2f81f | ||
|
|
17dead3309 | ||
|
|
979bd33dfb | ||
|
|
5128f072a8 | ||
|
|
2ad5b5cc2e | ||
|
|
24d8a96071 | ||
|
|
f1e4665aa2 | ||
|
|
1cbfea3a21 | ||
|
|
981e8e217d | ||
|
|
e7ca30f406 | ||
|
|
2832ca300f | ||
|
|
de5f413440 | ||
|
|
fbc14c61ea | ||
|
|
77e029a49f | ||
|
|
61b049ad35 | ||
|
|
b88f4a24d0 | ||
|
|
8c632f0d32 | ||
|
|
150a876c73 | ||
|
|
62c3b01e4f | ||
|
|
e1157f343b | ||
|
|
6a78739076 | ||
|
|
0794eb43e7 | ||
|
|
4ee54eac1d | ||
|
|
5851c46c81 | ||
|
|
a296559e79 | ||
|
|
1fd83f5e68 | ||
|
|
637487c573 | ||
|
|
4e98e7d0a2 | ||
|
|
12f65d800d | ||
|
|
45d09f8f51 | ||
|
|
2876c72fa9 | ||
|
|
9b4fdb493e | ||
|
|
47e21d6e04 | ||
|
|
84ab4a1c30 | ||
|
|
85c4304efd | ||
|
|
8f152f162b | ||
|
|
63b49f045a |
@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
|
||||
runs:
|
||||
using: 'composite'
|
||||
steps:
|
||||
- name: setup node 18
|
||||
- name: setup node 20
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
node-version: '20'
|
||||
|
||||
- name: setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 8.15.6
|
||||
version: 10
|
||||
run_install: false
|
||||
|
||||
- name: get pnpm store directory
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -180,6 +180,7 @@ cython_debug/
|
||||
# Scratch folder
|
||||
.scratch/
|
||||
.vscode/
|
||||
.zed/
|
||||
|
||||
# source installer files
|
||||
installer/*zip
|
||||
|
||||
@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
|
||||
<!-- links -->
|
||||
|
||||
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
|
||||
[zod]: https://github.com/colinhacks/zod 'zod'
|
||||
[zod]: https://github.com/colinhacks/zod 'zod/v4'
|
||||
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
|
||||
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
|
||||
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
|
||||
|
||||
@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
|
||||
|
||||
## Step 2: Download
|
||||
|
||||
Download the most launcher for your operating system:
|
||||
Download the most recent launcher for your operating system:
|
||||
|
||||
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
|
||||
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
class AddImagesToBoardResult(BaseModel):
|
||||
board_id: str = Field(description="The id of the board the images were added to")
|
||||
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(BaseModel):
|
||||
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="add_image_to_board",
|
||||
@@ -23,17 +14,26 @@ class RemoveImagesFromBoardResult(BaseModel):
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_image_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||
|
||||
@@ -45,14 +45,25 @@ async def add_image_to_board(
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_image_from_board(
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
):
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
return result
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
@@ -72,16 +83,25 @@ async def add_images_to_board(
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
try:
|
||||
added_image_names: list[str] = []
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
board_id=board_id,
|
||||
image_name=image_name,
|
||||
)
|
||||
added_image_names.append(image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
@@ -100,13 +120,20 @@ async def remove_images_from_board(
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
try:
|
||||
removed_image_names: list[str] = []
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_image_names.append(image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
except Exception:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||
|
||||
@@ -14,10 +14,17 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.images.images_common import (
|
||||
DeleteImagesResult,
|
||||
ImageDTO,
|
||||
ImageUrlsDTO,
|
||||
StarredImagesResult,
|
||||
UnstarredImagesResult,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
|
||||
@@ -65,7 +72,7 @@ async def upload_image(
|
||||
resize_to: Optional[str] = Body(
|
||||
default=None,
|
||||
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
|
||||
example='"[1024,1024]"',
|
||||
examples=['"[1024,1024]"'],
|
||||
),
|
||||
metadata: Optional[str] = Body(
|
||||
default=None,
|
||||
@@ -153,18 +160,30 @@ async def create_image_upload_entry(
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
) -> DeleteImagesResult:
|
||||
"""Deletes an image"""
|
||||
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except Exception:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
|
||||
@images_router.delete("/intermediates", operation_id="clear_intermediates")
|
||||
async def clear_intermediates() -> int:
|
||||
@@ -376,31 +395,32 @@ async def list_image_dtos(
|
||||
return image_dtos
|
||||
|
||||
|
||||
class DeleteImagesFromListResult(BaseModel):
|
||||
deleted_images: list[str]
|
||||
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
|
||||
async def delete_images_from_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesFromListResult:
|
||||
) -> DeleteImagesResult:
|
||||
try:
|
||||
deleted_images: list[str] = []
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.append(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
@images_router.delete(
|
||||
"/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesFromListResult
|
||||
)
|
||||
async def delete_uncategorized_images() -> DeleteImagesFromListResult:
|
||||
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
|
||||
async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
"""Deletes all images that are uncategorized"""
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
@@ -408,14 +428,19 @@ async def delete_uncategorized_images() -> DeleteImagesFromListResult:
|
||||
)
|
||||
|
||||
try:
|
||||
deleted_images: list[str] = []
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.append(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
@@ -424,36 +449,50 @@ class ImagesUpdatedFromListResult(BaseModel):
|
||||
updated_image_names: list[str] = Field(description="The image names that were updated")
|
||||
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
|
||||
async def star_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
) -> StarredImagesResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
starred_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||
updated_image_names.append(image_name)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=True)
|
||||
)
|
||||
starred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except Exception:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
return StarredImagesResult(
|
||||
starred_images=list(starred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
|
||||
async def unstar_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
) -> UnstarredImagesResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
unstarred_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||
updated_image_names.append(image_name)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=False)
|
||||
)
|
||||
unstarred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except Exception:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
return UnstarredImagesResult(
|
||||
unstarred_images=list(unstarred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
@@ -524,3 +563,61 @@ async def get_bulk_download_item(
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image names")
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/images_by_names",
|
||||
operation_id="get_images_by_names",
|
||||
responses={200: {"model": list[ImageDTO]}},
|
||||
)
|
||||
async def get_images_by_names(
|
||||
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
|
||||
) -> list[ImageDTO]:
|
||||
"""Gets image DTOs for the specified image names. Maintains order of input names."""
|
||||
|
||||
try:
|
||||
image_service = ApiDependencies.invoker.services.images
|
||||
|
||||
# Fetch DTOs preserving the order of requested names
|
||||
image_dtos: list[ImageDTO] = []
|
||||
for name in image_names:
|
||||
try:
|
||||
dto = image_service.get_dto(name)
|
||||
image_dtos.append(dto)
|
||||
except Exception:
|
||||
# Skip missing images - they may have been deleted between name fetch and DTO fetch
|
||||
continue
|
||||
|
||||
return image_dtos
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image DTOs")
|
||||
|
||||
@@ -41,6 +41,7 @@ from invokeai.backend.model_manager.starter_models import (
|
||||
STARTER_BUNDLES,
|
||||
STARTER_MODELS,
|
||||
StarterModel,
|
||||
StarterModelBundle,
|
||||
StarterModelWithoutDependencies,
|
||||
)
|
||||
|
||||
@@ -291,7 +292,7 @@ async def get_hugging_face_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -449,7 +450,7 @@ async def install_model(
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
examples=[{"name": "string", "description": "string"}],
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
@@ -799,7 +800,7 @@ async def convert_model(
|
||||
|
||||
class StarterModelResponse(BaseModel):
|
||||
starter_models: list[StarterModel]
|
||||
starter_bundles: dict[str, list[StarterModel]]
|
||||
starter_bundles: dict[str, StarterModelBundle]
|
||||
|
||||
|
||||
def get_is_installed(
|
||||
@@ -833,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse:
|
||||
model.dependencies = missing_deps
|
||||
|
||||
for bundle in starter_bundles.values():
|
||||
for model in bundle:
|
||||
for model in bundle.models:
|
||||
model.is_installed = get_is_installed(model, installed_models)
|
||||
# Remove already-installed dependencies
|
||||
missing_deps: list[StarterModelWithoutDependencies] = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -14,13 +14,15 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
@@ -58,17 +60,19 @@ async def enqueue_batch(
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
try:
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list",
|
||||
operation_id="list_queue_items",
|
||||
responses={
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def list_queue_items(
|
||||
@@ -77,12 +81,42 @@ async def list_queue_items(
|
||||
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
|
||||
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
|
||||
priority: int = Query(default=0, description="The pagination cursor priority"),
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets all queue items (without graphs)"""
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
"""Gets cursor-paginated queue items"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
|
||||
)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id,
|
||||
limit=limit,
|
||||
status=status,
|
||||
cursor=cursor,
|
||||
priority=priority,
|
||||
destination=destination,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list_all",
|
||||
operation_id="list_all_queue_items",
|
||||
responses={
|
||||
200: {"model": list[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def list_all_queue_items(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
queue_id=queue_id,
|
||||
destination=destination,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -94,7 +128,10 @@ async def resume(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -106,7 +143,10 @@ async def Pause(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -118,7 +158,25 @@ async def cancel_all_except_current(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> CancelAllExceptCurrentResult:
|
||||
"""Immediately cancels all queue items except in-processing items"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/delete_all_except_current",
|
||||
operation_id="delete_all_except_current",
|
||||
responses={200: {"model": DeleteAllExceptCurrentResult}},
|
||||
)
|
||||
async def delete_all_except_current(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> DeleteAllExceptCurrentResult:
|
||||
"""Immediately deletes all queue items except in-processing items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -131,7 +189,12 @@ async def cancel_by_batch_ids(
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
|
||||
queue_id=queue_id, batch_ids=batch_ids
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -144,9 +207,12 @@ async def cancel_by_destination(
|
||||
destination: str = Query(description="The destination to cancel all queue items for"),
|
||||
) -> CancelByDestinationResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -159,7 +225,10 @@ async def retry_items_by_id(
|
||||
item_ids: list[int] = Body(description="The queue item ids to retry"),
|
||||
) -> RetryItemsResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -173,11 +242,14 @@ async def clear(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -191,7 +263,10 @@ async def prune(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -205,7 +280,10 @@ async def get_current_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -219,7 +297,10 @@ async def get_next_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -233,9 +314,12 @@ async def get_queue_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -250,7 +334,10 @@ async def get_batch_status(
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -266,7 +353,27 @@ async def get_queue_item(
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
"""Gets a queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.delete(
|
||||
"/{queue_id}/i/{item_id}",
|
||||
operation_id="delete_queue_item",
|
||||
)
|
||||
async def delete_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to delete"),
|
||||
) -> None:
|
||||
"""Deletes a queue item"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -281,8 +388,12 @@ async def cancel_queue_item(
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -295,6 +406,27 @@ async def counts_by_destination(
|
||||
destination: str = Query(description="The destination to query"),
|
||||
) -> SessionQueueCountsByDestination:
|
||||
"""Gets the counts of queue items by destination"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
|
||||
|
||||
|
||||
@session_queue_router.delete(
|
||||
"/{queue_id}/d/{destination}",
|
||||
operation_id="delete_by_destination",
|
||||
responses={200: {"model": DeleteByDestinationResult}},
|
||||
)
|
||||
async def delete_by_destination(
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Path(description="The destination to query"),
|
||||
) -> DeleteByDestinationResult:
|
||||
"""Deletes all items with the given destination"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
|
||||
|
||||
@@ -582,6 +582,8 @@ def invocation(
|
||||
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
original_model_fields: dict[str, OriginalModelField] = {}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
||||
@@ -589,7 +591,7 @@ def invocation(
|
||||
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
|
||||
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
|
||||
|
||||
@@ -676,6 +678,7 @@ def invocation(
|
||||
docstring = cls.__doc__
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
|
||||
new_class.__doc__ = docstring
|
||||
new_class._original_model_fields = original_model_fields
|
||||
|
||||
InvocationRegistry.register_invocation(new_class)
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
Imagen3Model = "Imagen3ModelField"
|
||||
Imagen4Model = "Imagen4ModelField"
|
||||
ChatGPT4oModel = "ChatGPT4oModelField"
|
||||
FluxKontextModel = "FluxKontextModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -214,6 +215,7 @@ class FieldDescriptions:
|
||||
flux_redux_conditioning = "FLUX Redux conditioning tensor"
|
||||
vllm_model = "The VLLM model to use"
|
||||
flux_fill_conditioning = "FLUX Fill conditioning tensor"
|
||||
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
@@ -290,6 +292,12 @@ class FluxFillConditioningField(BaseModel):
|
||||
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
|
||||
|
||||
|
||||
class FluxKontextConditioningField(BaseModel):
|
||||
"""A conditioning field for FLUX Kontext (reference image)."""
|
||||
|
||||
image: ImageField = Field(description="The Kontext reference image.")
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
@@ -16,13 +16,12 @@ from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxFillConditioningField,
|
||||
FluxKontextConditioningField,
|
||||
FluxReduxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
@@ -34,6 +33,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.3.0",
|
||||
version="4.0.0",
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class FluxDenoiseInvocation(BaseInvocation):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
@@ -145,11 +145,20 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# This node accepts a images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
|
||||
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
|
||||
# upstream nodes.
|
||||
|
||||
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
|
||||
)
|
||||
|
||||
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
|
||||
default=None,
|
||||
description="FLUX Kontext conditioning (reference image).",
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
@@ -376,6 +385,27 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
kontext_extension = None
|
||||
if self.kontext_conditioning is not None:
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
|
||||
|
||||
kontext_extension = KontextExtension(
|
||||
context=context,
|
||||
kontext_conditioning=self.kontext_conditioning,
|
||||
vae_field=self.controlnet_vae,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
# Prepare Kontext conditioning if provided
|
||||
img_cond_seq = None
|
||||
img_cond_seq_ids = None
|
||||
if kontext_extension is not None:
|
||||
# Ensure batch sizes match
|
||||
kontext_extension.ensure_batch_size(x.shape[0])
|
||||
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
@@ -391,6 +421,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
img_cond=img_cond,
|
||||
img_cond_seq=img_cond_seq,
|
||||
img_cond_seq_ids=img_cond_seq_ids,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
@@ -865,7 +897,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
# The denoise function now handles Kontext conditioning correctly,
|
||||
# so we don't need to slice the latents here
|
||||
latents = state.latents.float()
|
||||
state.latents = unpack(latents, self.height, self.width).squeeze()
|
||||
context.util.flux_step_callback(state)
|
||||
|
||||
return step_callback
|
||||
|
||||
40
invokeai/app/invocations/flux_kontext.py
Normal file
40
invokeai/app/invocations/flux_kontext.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxKontextConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_kontext_output")
|
||||
class FluxKontextOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Kontext invocation."""
|
||||
|
||||
kontext_cond: FluxKontextConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_kontext",
|
||||
title="Kontext Conditioning - FLUX",
|
||||
tags=["conditioning", "kontext", "flux"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextInvocation(BaseInvocation):
|
||||
"""Prepares a reference image for FLUX Kontext conditioning."""
|
||||
|
||||
image: ImageField = InputField(description="The Kontext reference image.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
|
||||
"""Packages the provided image into a Kontext conditioning field."""
|
||||
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))
|
||||
@@ -1,5 +1,5 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
from typing import Iterator, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
@@ -111,6 +111,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_t5_tokenization(context, t5_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
@@ -151,6 +154,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_clip_tokenization(context, clip_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
@@ -170,3 +176,88 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _log_t5_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
|
||||
|
||||
# Tokenize the prompt using the same parameters as the model's text encoder.
|
||||
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=self.t5_max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True, # This is important for T5 to add the EOS token.
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The T5 tokenizer uses a space-like character ' ' (U+2581) to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("\u2581", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for token in tokens:
|
||||
if token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
used_tokens += 1
|
||||
elif token == tokenizer.pad_token:
|
||||
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
|
||||
continue
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
def _log_clip_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: CLIPTokenizer,
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a CLIP-based model."""
|
||||
max_length = tokenizer.model_max_length
|
||||
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
attention_mask = tokenized_output.attention_mask[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The CLIP tokenizer uses '</w>' to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("</w>", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for i, token in enumerate(tokens):
|
||||
if attention_mask[i] == 0:
|
||||
# Do not log padding tokens.
|
||||
continue
|
||||
|
||||
if token == tokenizer.bos_token:
|
||||
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
|
||||
elif token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@@ -97,3 +98,17 @@ class ImageRecordStorageBase(ABC):
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field, StrictBool, StrictStr
|
||||
from pydantic import BaseModel, Field, StrictBool, StrictStr
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
@@ -207,3 +207,16 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
starred=starred,
|
||||
has_workflow=has_workflow,
|
||||
)
|
||||
|
||||
|
||||
class ImageCollectionCounts(BaseModel):
|
||||
starred_count: int = Field(description="The number of starred images in the collection.")
|
||||
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
|
||||
|
||||
|
||||
class ImageNamesResult(BaseModel):
|
||||
"""Response containing ordered image names with metadata for optimistic updates."""
|
||||
|
||||
image_names: list[str] = Field(description="Ordered list of image names")
|
||||
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
|
||||
total_count: int = Field(description="Total number of images matching the query")
|
||||
|
||||
@@ -7,6 +7,7 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
IMAGE_DTO_COLS,
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@@ -386,3 +387,96 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return None
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [row[0] for row in result]
|
||||
|
||||
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
|
||||
|
||||
@@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@@ -125,7 +126,7 @@ class ImageServiceABC(ABC):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -147,3 +148,17 @@ class ImageServiceABC(ABC):
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
"""Deletes all images on a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecord
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
@@ -39,3 +39,27 @@ def image_record_to_dto(
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
class ResultWithAffectedBoards(BaseModel):
|
||||
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
|
||||
|
||||
|
||||
class DeleteImagesResult(ResultWithAffectedBoards):
|
||||
deleted_images: list[str] = Field(description="The names of the images that were deleted")
|
||||
|
||||
|
||||
class StarredImagesResult(ResultWithAffectedBoards):
|
||||
starred_images: list[str] = Field(description="The names of the images that were starred")
|
||||
|
||||
|
||||
class UnstarredImagesResult(ResultWithAffectedBoards):
|
||||
unstarred_images: list[str] = Field(description="The names of the images that were unstarred")
|
||||
|
||||
|
||||
class AddImagesToBoardResult(ResultWithAffectedBoards):
|
||||
added_images: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(ResultWithAffectedBoards):
|
||||
removed_images: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
@@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import (
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@@ -309,3 +310,27 @@ class ImageService(ImageServiceABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting intermediates count")
|
||||
raise e
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image names")
|
||||
raise e
|
||||
|
||||
@@ -10,6 +10,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
@@ -17,7 +19,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
@@ -92,6 +93,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_queue_item(self, item_id: int) -> None:
|
||||
"""Deletes a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fail_queue_item(
|
||||
self, item_id: int, error_type: str, error_message: str, error_traceback: str
|
||||
@@ -109,6 +115,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items with the given batch destination"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
"""Deletes all queue items with the given batch destination"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
@@ -119,6 +130,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items except in-progress items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
"""Deletes all queue items except in-progress items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_queue_items(
|
||||
self,
|
||||
@@ -127,10 +143,20 @@ class SessionQueueBase(ABC):
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
"""Gets a page of session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
|
||||
@@ -205,9 +205,10 @@ class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
user_label: str | None = Field(description="The user label of the field, if any")
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
class SessionQueueItem(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
item_id: int = Field(description="The identifier of the session queue item")
|
||||
@@ -251,42 +252,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
default=None,
|
||||
description="The ID of the published workflow associated with this queue item",
|
||||
)
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
)
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The nodes that were used as output from the API"
|
||||
)
|
||||
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
|
||||
|
||||
@classmethod
|
||||
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
return SessionQueueItemDTO(**queue_item_dict)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
"queue_id",
|
||||
"session_id",
|
||||
"priority",
|
||||
"session_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
|
||||
class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
default=None, description="The workflow associated with this queue item"
|
||||
@@ -366,6 +332,7 @@ class EnqueueBatchResult(BaseModel):
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
priority: int = Field(description="The priority of the enqueued batch")
|
||||
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
|
||||
|
||||
|
||||
class RetryItemsResult(BaseModel):
|
||||
@@ -397,6 +364,18 @@ class CancelByDestinationResult(CancelByBatchIDsResult):
|
||||
pass
|
||||
|
||||
|
||||
class DeleteByDestinationResult(BaseModel):
|
||||
"""Result of deleting by a destination"""
|
||||
|
||||
deleted: int = Field(..., description="Number of queue items deleted")
|
||||
|
||||
|
||||
class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
|
||||
"""Result of deleting all except current"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
@@ -24,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
ValueToInsertTuple,
|
||||
@@ -46,10 +47,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
clear_result = self.clear(DEFAULT_QUEUE_ID)
|
||||
if clear_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
|
||||
else:
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
@@ -136,6 +133,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT item_id
|
||||
FROM session_queue
|
||||
WHERE batch_id = ?
|
||||
ORDER BY item_id DESC;
|
||||
""",
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
except Exception:
|
||||
raise
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
@@ -144,6 +153,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
enqueued=enqueued_count,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
item_ids=item_ids,
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
|
||||
return enqueue_result
|
||||
@@ -217,6 +227,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status FROM session_queue WHERE item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -328,6 +351,27 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
return queue_item
|
||||
|
||||
def delete_queue_item(self, item_id: int) -> None:
|
||||
"""Deletes a session queue item"""
|
||||
try:
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
return queue_item
|
||||
@@ -360,6 +404,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
cursor.execute(
|
||||
@@ -398,6 +444,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
cursor.execute(
|
||||
@@ -425,6 +473,71 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
raise
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
params = (queue_id, destination)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND destination = ?;
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND destination = ?;
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND status == 'pending'
|
||||
"""
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
@@ -435,6 +548,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = [queue_id]
|
||||
cursor.execute(
|
||||
@@ -455,12 +570,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
@@ -540,26 +652,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -571,6 +669,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
@@ -586,7 +690,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
@@ -594,6 +698,37 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
has_more = True
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
cursor_ = self._conn.cursor()
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -608,7 +743,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] for row in counts_result)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
return SessionQueueStatus(
|
||||
queue_id=queue_id,
|
||||
@@ -637,7 +772,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
total = sum(row[1] or 0 for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
@@ -669,7 +804,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] for row in counts_result)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
|
||||
return SessionQueueCountsByDestination(
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
ValidationError,
|
||||
@@ -57,17 +58,32 @@ class Edge(BaseModel):
|
||||
|
||||
|
||||
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_outputs = get_type_hints(node_type.get_output_annotation())
|
||||
node_output_field = node_outputs.get(field) or None
|
||||
return node_output_field
|
||||
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
|
||||
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
|
||||
# would require some fairly significant changes and I don't want risk breaking anything.
|
||||
try:
|
||||
invocation_class = type(node)
|
||||
invocation_output_class = invocation_class.get_output_annotation()
|
||||
field_info = invocation_output_class.model_fields.get(field)
|
||||
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
|
||||
output_field_type = field_info.annotation
|
||||
return output_field_type
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_inputs = get_type_hints(node_type)
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
|
||||
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
|
||||
# would require some fairly significant changes and I don't want risk breaking anything.
|
||||
try:
|
||||
invocation_class = type(node)
|
||||
field_info = invocation_class.model_fields.get(field)
|
||||
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
|
||||
input_field_type = field_info.annotation
|
||||
return input_field_type
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
@@ -787,6 +803,22 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
"executed",
|
||||
"executed_history",
|
||||
"results",
|
||||
"errors",
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
@@ -975,10 +1007,11 @@ class GraphExecutionState(BaseModel):
|
||||
new_node_ids = []
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(
|
||||
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
|
||||
)
|
||||
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||
all_iteration_mappings = []
|
||||
for source_node_id in next_node_parents:
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
||||
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
|
||||
|
||||
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
|
||||
@@ -123,7 +123,11 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
|
||||
if total_steps == 0:
|
||||
return 0.0
|
||||
if order == 2:
|
||||
return floor(step / 2) / floor(total_steps / 2)
|
||||
# Prevent division by zero when total_steps is 1 or 2
|
||||
denominator = floor(total_steps / 2)
|
||||
if denominator == 0:
|
||||
return 0.0
|
||||
return floor(step / 2) / denominator
|
||||
# order == 1
|
||||
return step / total_steps
|
||||
|
||||
|
||||
@@ -30,8 +30,11 @@ def denoise(
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
# extra img tokens
|
||||
# extra img tokens (channel-wise)
|
||||
img_cond: torch.Tensor | None,
|
||||
# extra img tokens (sequence-wise) - for Kontext conditioning
|
||||
img_cond_seq: torch.Tensor | None = None,
|
||||
img_cond_seq_ids: torch.Tensor | None = None,
|
||||
):
|
||||
# step 0 is the initial state
|
||||
total_steps = len(timesteps) - 1
|
||||
@@ -46,6 +49,10 @@ def denoise(
|
||||
)
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
# Store original sequence length for slicing predictions
|
||||
original_seq_len = img.shape[1]
|
||||
|
||||
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
@@ -71,10 +78,26 @@ def denoise(
|
||||
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
||||
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
|
||||
|
||||
# Prepare input for model - concatenate fresh each step
|
||||
img_input = img
|
||||
img_input_ids = img_ids
|
||||
|
||||
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
|
||||
if img_cond is not None:
|
||||
img_input = torch.cat((img_input, img_cond), dim=-1)
|
||||
|
||||
# Add sequence-wise conditioning (for Kontext)
|
||||
if img_cond_seq is not None:
|
||||
assert img_cond_seq_ids is not None, (
|
||||
"You need to provide either both or neither of the sequence conditioning"
|
||||
)
|
||||
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||
|
||||
pred = model(
|
||||
img=pred_img,
|
||||
img_ids=img_ids,
|
||||
img=img_input,
|
||||
img_ids=img_input_ids,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
@@ -88,6 +111,10 @@ def denoise(
|
||||
regional_prompting_extension=pos_regional_prompting_extension,
|
||||
)
|
||||
|
||||
# Slice prediction to only include the main image tokens
|
||||
if img_input_ids is not None:
|
||||
pred = pred[:, :original_seq_len]
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
|
||||
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
|
||||
|
||||
149
invokeai/backend/flux/extensions/kontext_extension.py
Normal file
149
invokeai/backend/flux/extensions/kontext_extension.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.fields import FluxKontextConditioningField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import pack
|
||||
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
|
||||
|
||||
def generate_img_ids_with_offset(
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
idx_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids with an optional offset.
|
||||
|
||||
Args:
|
||||
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
|
||||
latent_width (int): Width of image in latent space (after packing, this becomes w//2).
|
||||
batch_size (int): Number of images in the batch.
|
||||
device (torch.device): Device to create tensors on.
|
||||
dtype (torch.dtype): Data type for the tensors.
|
||||
idx_offset (int): Offset to add to the first dimension of the image ids.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
|
||||
"""
|
||||
|
||||
if device.type == "mps":
|
||||
orig_dtype = dtype
|
||||
dtype = torch.float16
|
||||
|
||||
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
|
||||
packed_height = latent_height // 2
|
||||
packed_width = latent_width // 2
|
||||
|
||||
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
|
||||
# The 3 channels represent: [batch_offset, y_position, x_position]
|
||||
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
|
||||
|
||||
# Set the batch offset for all positions
|
||||
img_ids[..., 0] = idx_offset
|
||||
|
||||
# Create y-coordinate indices (vertical positions)
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
|
||||
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
|
||||
img_ids[..., 1] = y_indices[:, None]
|
||||
|
||||
# Create x-coordinate indices (horizontal positions)
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
|
||||
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
|
||||
img_ids[..., 2] = x_indices[None, :]
|
||||
|
||||
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
|
||||
if device.type == "mps":
|
||||
img_ids = img_ids.to(orig_dtype)
|
||||
|
||||
return img_ids
|
||||
|
||||
|
||||
class KontextExtension:
|
||||
"""Applies FLUX Kontext (reference image) conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kontext_conditioning: FluxKontextConditioningField,
|
||||
context: InvocationContext,
|
||||
vae_field: VAEField,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Initializes the KontextExtension, pre-processing the reference image
|
||||
into latents and positional IDs.
|
||||
"""
|
||||
self._context = context
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
self._vae_field = vae_field
|
||||
self.kontext_conditioning = kontext_conditioning
|
||||
|
||||
# Pre-process and cache the kontext latents and ids upon initialization.
|
||||
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
|
||||
|
||||
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encodes the reference image and prepares its latents and IDs."""
|
||||
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
|
||||
|
||||
# Calculate aspect ratio of input image
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Find the closest preferred resolution by aspect ratio
|
||||
_, target_width, target_height = min(
|
||||
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
|
||||
)
|
||||
|
||||
# Apply BFL's scaling formula
|
||||
# This ensures compatibility with the model's training
|
||||
scaled_width = 2 * int(target_width / 16)
|
||||
scaled_height = 2 * int(target_height / 16)
|
||||
|
||||
# Resize to the exact resolution used during training
|
||||
image = image.convert("RGB")
|
||||
final_width = 8 * scaled_width
|
||||
final_height = 8 * scaled_height
|
||||
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert to tensor with same normalization as BFL
|
||||
image_np = np.array(image)
|
||||
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
|
||||
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
|
||||
image_tensor = image_tensor.to(self._device)
|
||||
|
||||
# Continue with VAE encoding
|
||||
vae_info = self._context.models.load(self._vae_field.vae)
|
||||
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
# Extract tensor dimensions
|
||||
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
|
||||
|
||||
# Pack the latents and generate IDs
|
||||
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
|
||||
kontext_ids = generate_img_ids_with_offset(
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
batch_size=batch_size,
|
||||
device=self._device,
|
||||
dtype=self._dtype,
|
||||
idx_offset=1,
|
||||
)
|
||||
|
||||
return kontext_latents_packed, kontext_ids
|
||||
|
||||
def ensure_batch_size(self, target_batch_size: int) -> None:
|
||||
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
|
||||
if self.kontext_latents.shape[0] != target_batch_size:
|
||||
self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
|
||||
self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)
|
||||
@@ -174,11 +174,13 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
|
||||
dtype = torch.float16
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
# Set batch offset to 0 for main image tokens
|
||||
img_ids[..., 0] = 0
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
|
||||
if device.type == "mps":
|
||||
img_ids.to(orig_dtype)
|
||||
img_ids = img_ids.to(orig_dtype)
|
||||
|
||||
return img_ids
|
||||
|
||||
@@ -18,6 +18,29 @@ class ModelSpec:
|
||||
repo_ae: str | None
|
||||
|
||||
|
||||
# Preferred resolutions for Kontext models to avoid tiling artifacts
|
||||
# These are the specific resolutions the model was trained on
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
(752, 1392),
|
||||
(800, 1328),
|
||||
(832, 1248),
|
||||
(880, 1184),
|
||||
(944, 1104),
|
||||
(1024, 1024),
|
||||
(1104, 944),
|
||||
(1184, 880),
|
||||
(1248, 832),
|
||||
(1328, 800),
|
||||
(1392, 752),
|
||||
(1456, 720),
|
||||
(1504, 688),
|
||||
(1568, 672),
|
||||
]
|
||||
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-dev-fill": 512,
|
||||
|
||||
@@ -37,6 +37,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
@@ -334,6 +335,36 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
|
||||
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
|
||||
metadata = mod.metadata()
|
||||
return (
|
||||
metadata.get("modelspec.sai_model_spec")
|
||||
and metadata.get("ot_branch") == "omi_format"
|
||||
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
metadata = mod.metadata()
|
||||
architecture = metadata["modelspec.architecture"]
|
||||
|
||||
if architecture == stable_diffusion_xl_1_lora:
|
||||
base = BaseModelType.StableDiffusionXL
|
||||
elif architecture == flux_dev_1_lora:
|
||||
base = BaseModelType.Flux
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
|
||||
|
||||
return {"base": base}
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
@@ -350,7 +381,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if type(key) is int:
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
|
||||
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
@@ -668,6 +699,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
|
||||
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
|
||||
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
|
||||
@@ -7,7 +7,14 @@ from typing import Optional
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForTextEncoding,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@@ -139,7 +146,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
)
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
te2_model_path = Path(config.path) / "text_encoder_2"
|
||||
model_config = AutoConfig.from_pretrained(te2_model_path)
|
||||
@@ -183,7 +190,7 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
return T5EncoderModel.from_pretrained(
|
||||
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
|
||||
|
||||
@@ -13,6 +13,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.omi.omi import convert_from_omi
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
@@ -43,6 +44,8 @@ from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import l
|
||||
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
|
||||
@@ -77,12 +80,23 @@ class LoRALoader(ModelLoader):
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Strip 'bundle_emb' keys - these are unused and currently cause downstream errors.
|
||||
# To revisit later to determine if they're needed/useful.
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")}
|
||||
|
||||
# At the time of writing, we support the OMI standard for base models Flux and SDXL
|
||||
if config.format == ModelFormat.OMI and self._model_base in [
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.Flux,
|
||||
]:
|
||||
state_dict = convert_from_omi(state_dict, config.base) # type: ignore
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
|
||||
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
|
||||
# distributed as a single file without the associated metadata containing the alpha value. We chose
|
||||
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
|
||||
@@ -99,7 +113,7 @@ class LoRALoader(ModelLoader):
|
||||
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
raise ValueError("LoRA model is in unsupported FLUX format")
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
|
||||
7
invokeai/backend/model_manager/omi/__init__.py
Normal file
7
invokeai/backend/model_manager/omi/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from invokeai.backend.model_manager.omi.omi import convert_from_omi
|
||||
from invokeai.backend.model_manager.omi.vendor.model_spec.architecture import (
|
||||
flux_dev_1_lora,
|
||||
stable_diffusion_xl_1_lora,
|
||||
)
|
||||
|
||||
__all__ = ["flux_dev_1_lora", "stable_diffusion_xl_1_lora", "convert_from_omi"]
|
||||
21
invokeai/backend/model_manager/omi/omi.py
Normal file
21
invokeai/backend/model_manager/omi/omi.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from invokeai.backend.model_manager.model_on_disk import StateDict
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
|
||||
convert_flux_lora as omi_flux,
|
||||
)
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
|
||||
convert_lora_util as lora_util,
|
||||
)
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
|
||||
convert_sdxl_lora as omi_sdxl,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
|
||||
keyset = {
|
||||
BaseModelType.Flux: omi_flux.convert_flux_lora_key_sets(),
|
||||
BaseModelType.StableDiffusionXL: omi_sdxl.convert_sdxl_lora_key_sets(),
|
||||
}[base]
|
||||
source = "omi"
|
||||
target = "legacy_diffusers"
|
||||
return lora_util.__convert(weights_sd, keyset, source, target) # type: ignore
|
||||
0
invokeai/backend/model_manager/omi/vendor/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/lora/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/lora/__init__.py
vendored
Normal file
20
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_clip.py
vendored
Normal file
20
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_clip.py
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
|
||||
|
||||
def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("text_projection", "text_projection", parent=key_prefix)]
|
||||
|
||||
for k in map_prefix_range("text_model.encoder.layers", "text_model.encoder.layers", parent=key_prefix):
|
||||
keys += [LoraConversionKeySet("mlp.fc1", "mlp.fc1", parent=k)]
|
||||
keys += [LoraConversionKeySet("mlp.fc2", "mlp.fc2", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.k_proj", "self_attn.k_proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.out_proj", "self_attn.out_proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.q_proj", "self_attn.q_proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.v_proj", "self_attn.v_proj", parent=k)]
|
||||
|
||||
return keys
|
||||
84
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_flux_lora.py
vendored
Normal file
84
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_flux_lora.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5
|
||||
|
||||
|
||||
def __map_double_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("img_attn.qkv.0", "attn.to_q", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_attn.qkv.1", "attn.to_k", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_attn.qkv.2", "attn.to_v", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("txt_attn.qkv.0", "attn.add_q_proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_attn.qkv.1", "attn.add_k_proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_attn.qkv.2", "attn.add_v_proj", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("img_attn.proj", "attn.to_out.0", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_mlp.0", "ff.net.0.proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_mlp.2", "ff.net.2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_mod.lin", "norm1.linear", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("txt_attn.proj", "attn.to_add_out", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_mlp.0", "ff_context.net.0.proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_mlp.2", "ff_context.net.2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_mod.lin", "norm1_context.linear", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_single_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("linear1.0", "attn.to_q", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("linear1.1", "attn.to_k", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("linear1.2", "attn.to_v", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("linear1.3", "proj_mlp", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("linear2", "proj_out", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("modulation.lin", "norm.linear", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
|
||||
keys += [
|
||||
LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)
|
||||
]
|
||||
keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
|
||||
keys += [
|
||||
LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)
|
||||
]
|
||||
keys += [
|
||||
LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)
|
||||
]
|
||||
keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("time_in.out_layer", "time_text_embed.timestep_embedder.linear_2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_in.proj", "x_embedder", parent=key_prefix)]
|
||||
|
||||
for k in map_prefix_range("double_blocks", "transformer_blocks", parent=key_prefix):
|
||||
keys += __map_double_transformer_block(k)
|
||||
|
||||
for k in map_prefix_range("single_blocks", "single_transformer_blocks", parent=key_prefix):
|
||||
keys += __map_single_transformer_block(k)
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def convert_flux_lora_key_sets() -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
|
||||
keys += __map_transformer(LoraConversionKeySet("transformer", "lora_transformer"))
|
||||
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
|
||||
keys += map_t5(LoraConversionKeySet("t5", "lora_te2"))
|
||||
|
||||
return keys
|
||||
217
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_lora_util.py
vendored
Normal file
217
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_lora_util.py
vendored
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class LoraConversionKeySet:
|
||||
def __init__(
|
||||
self,
|
||||
omi_prefix: str,
|
||||
diffusers_prefix: str,
|
||||
legacy_diffusers_prefix: str | None = None,
|
||||
parent: Self | None = None,
|
||||
swap_chunks: bool = False,
|
||||
filter_is_last: bool | None = None,
|
||||
next_omi_prefix: str | None = None,
|
||||
next_diffusers_prefix: str | None = None,
|
||||
):
|
||||
if parent is not None:
|
||||
self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
|
||||
self.diffusers_prefix = combine(parent.diffusers_prefix, diffusers_prefix)
|
||||
else:
|
||||
self.omi_prefix = omi_prefix
|
||||
self.diffusers_prefix = diffusers_prefix
|
||||
|
||||
if legacy_diffusers_prefix is None:
|
||||
self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
|
||||
elif parent is not None:
|
||||
self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
|
||||
".", "_"
|
||||
)
|
||||
else:
|
||||
self.legacy_diffusers_prefix = legacy_diffusers_prefix
|
||||
|
||||
self.parent = parent
|
||||
self.swap_chunks = swap_chunks
|
||||
self.filter_is_last = filter_is_last
|
||||
self.prefix = parent
|
||||
|
||||
if next_omi_prefix is None and parent is not None:
|
||||
self.next_omi_prefix = parent.next_omi_prefix
|
||||
self.next_diffusers_prefix = parent.next_diffusers_prefix
|
||||
self.next_legacy_diffusers_prefix = parent.next_legacy_diffusers_prefix
|
||||
elif next_omi_prefix is not None and parent is not None:
|
||||
self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
|
||||
self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
|
||||
self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
|
||||
".", "_"
|
||||
)
|
||||
elif next_omi_prefix is not None and parent is None:
|
||||
self.next_omi_prefix = next_omi_prefix
|
||||
self.next_diffusers_prefix = next_diffusers_prefix
|
||||
self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
|
||||
else:
|
||||
self.next_omi_prefix = None
|
||||
self.next_diffusers_prefix = None
|
||||
self.next_legacy_diffusers_prefix = None
|
||||
|
||||
def __get_omi(self, in_prefix: str, key: str) -> str:
|
||||
return self.omi_prefix + key.removeprefix(in_prefix)
|
||||
|
||||
def __get_diffusers(self, in_prefix: str, key: str) -> str:
|
||||
return self.diffusers_prefix + key.removeprefix(in_prefix)
|
||||
|
||||
def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
|
||||
key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)
|
||||
|
||||
suffix = key[key.rfind(".") :]
|
||||
if suffix not in [".alpha", ".dora_scale"]: # some keys only have a single . in the suffix
|
||||
suffix = key[key.removesuffix(suffix).rfind(".") :]
|
||||
key = key.removesuffix(suffix)
|
||||
|
||||
return key.replace(".", "_") + suffix
|
||||
|
||||
def get_key(self, in_prefix: str, key: str, target: str) -> str:
|
||||
if target == "omi":
|
||||
return self.__get_omi(in_prefix, key)
|
||||
elif target == "diffusers":
|
||||
return self.__get_diffusers(in_prefix, key)
|
||||
elif target == "legacy_diffusers":
|
||||
return self.__get_legacy_diffusers(in_prefix, key)
|
||||
return key
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"omi: {self.omi_prefix}, diffusers: {self.diffusers_prefix}, legacy: {self.legacy_diffusers_prefix}"
|
||||
|
||||
|
||||
def combine(left: str, right: str) -> str:
|
||||
left = left.rstrip(".")
|
||||
right = right.lstrip(".")
|
||||
if left == "" or left is None:
|
||||
return right
|
||||
elif right == "" or right is None:
|
||||
return left
|
||||
else:
|
||||
return left + "." + right
|
||||
|
||||
|
||||
def map_prefix_range(
|
||||
omi_prefix: str,
|
||||
diffusers_prefix: str,
|
||||
parent: LoraConversionKeySet,
|
||||
) -> list[LoraConversionKeySet]:
|
||||
# 100 should be a safe upper bound. increase if it's not enough in the future
|
||||
return [
|
||||
LoraConversionKeySet(
|
||||
omi_prefix=f"{omi_prefix}.{i}",
|
||||
diffusers_prefix=f"{diffusers_prefix}.{i}",
|
||||
parent=parent,
|
||||
next_omi_prefix=f"{omi_prefix}.{i + 1}",
|
||||
next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
|
||||
)
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
|
||||
def __convert(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
source: str,
|
||||
target: str,
|
||||
) -> dict[str, Tensor]:
|
||||
out_states = {}
|
||||
|
||||
if source == target:
|
||||
return dict(state_dict)
|
||||
|
||||
# TODO: maybe replace with a non O(n^2) algorithm
|
||||
for key, tensor in state_dict.items():
|
||||
for key_set in key_sets:
|
||||
in_prefix = ""
|
||||
|
||||
if source == "omi":
|
||||
in_prefix = key_set.omi_prefix
|
||||
elif source == "diffusers":
|
||||
in_prefix = key_set.diffusers_prefix
|
||||
elif source == "legacy_diffusers":
|
||||
in_prefix = key_set.legacy_diffusers_prefix
|
||||
|
||||
if not key.startswith(in_prefix):
|
||||
continue
|
||||
|
||||
if key_set.filter_is_last is not None:
|
||||
next_prefix = None
|
||||
if source == "omi":
|
||||
next_prefix = key_set.next_omi_prefix
|
||||
elif source == "diffusers":
|
||||
next_prefix = key_set.next_diffusers_prefix
|
||||
elif source == "legacy_diffusers":
|
||||
next_prefix = key_set.next_legacy_diffusers_prefix
|
||||
|
||||
is_last = not any(k.startswith(next_prefix) for k in state_dict)
|
||||
if key_set.filter_is_last != is_last:
|
||||
continue
|
||||
|
||||
name = key_set.get_key(in_prefix, key, target)
|
||||
|
||||
can_swap_chunks = target == "omi" or source == "omi"
|
||||
if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
|
||||
chunk_0, chunk_1 = tensor.chunk(2, dim=0)
|
||||
tensor = torch.cat([chunk_1, chunk_0], dim=0)
|
||||
|
||||
out_states[name] = tensor
|
||||
|
||||
break # only map the first matching key set
|
||||
|
||||
return out_states
|
||||
|
||||
|
||||
def __detect_source(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> str:
|
||||
omi_count = 0
|
||||
diffusers_count = 0
|
||||
legacy_diffusers_count = 0
|
||||
|
||||
for key in state_dict:
|
||||
for key_set in key_sets:
|
||||
if key.startswith(key_set.omi_prefix):
|
||||
omi_count += 1
|
||||
if key.startswith(key_set.diffusers_prefix):
|
||||
diffusers_count += 1
|
||||
if key.startswith(key_set.legacy_diffusers_prefix):
|
||||
legacy_diffusers_count += 1
|
||||
|
||||
if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
|
||||
return "omi"
|
||||
if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
|
||||
return "diffusers"
|
||||
if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
|
||||
return "legacy_diffusers"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def convert_to_omi(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> dict[str, Tensor]:
|
||||
source = __detect_source(state_dict, key_sets)
|
||||
return __convert(state_dict, key_sets, source, "omi")
|
||||
|
||||
|
||||
def convert_to_diffusers(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> dict[str, Tensor]:
|
||||
source = __detect_source(state_dict, key_sets)
|
||||
return __convert(state_dict, key_sets, source, "diffusers")
|
||||
|
||||
|
||||
def convert_to_legacy_diffusers(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> dict[str, Tensor]:
|
||||
source = __detect_source(state_dict, key_sets)
|
||||
return __convert(state_dict, key_sets, source, "legacy_diffusers")
|
||||
125
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_sdxl_lora.py
vendored
Normal file
125
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_sdxl_lora.py
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
|
||||
|
||||
def __map_unet_resnet_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("emb_layers.1", "time_emb_proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("in_layers.2", "conv1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("out_layers.3", "conv2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("skip_connection", "conv_shortcut", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_attention_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("proj_in", "proj_in", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("proj_out", "proj_out", parent=key_prefix)]
|
||||
for k in map_prefix_range("transformer_blocks", "transformer_blocks", parent=key_prefix):
|
||||
keys += [LoraConversionKeySet("attn1.to_q", "attn1.to_q", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn1.to_k", "attn1.to_k", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn1.to_v", "attn1.to_v", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn1.to_out.0", "attn1.to_out.0", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_q", "attn2.to_q", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_k", "attn2.to_k", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_v", "attn2.to_v", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_out.0", "attn2.to_out.0", parent=k)]
|
||||
keys += [LoraConversionKeySet("ff.net.0.proj", "ff.net.0.proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("ff.net.2", "ff.net.2", parent=k)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_down_blocks(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.1", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("3.0.op", "0.downsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.1", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("6.0.op", "1.downsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("7.1", "2.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("8.1", "2.attentions.1", parent=key_prefix))
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_mid_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("0", "resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("1", "attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("2", "resnets.1", parent=key_prefix))
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_up_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("0.0", "0.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("0.1", "0.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("1.1", "0.attentions.1", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.2", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("2.1", "0.attentions.2", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("2.2.conv", "0.upsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("3.0", "1.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("3.1", "1.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.1", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.2", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.2", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("5.2.conv", "1.upsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("6.0", "2.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.2", parent=key_prefix))
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("input_blocks.0.0", "conv_in", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("time_embed.0", "time_embedding.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("time_embed.2", "time_embedding.linear_2", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("label_emb.0.0", "add_embedding.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("label_emb.0.2", "add_embedding.linear_2", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_down_blocks(LoraConversionKeySet("input_blocks", "down_blocks", parent=key_prefix))
|
||||
keys += __map_unet_mid_block(LoraConversionKeySet("middle_block", "mid_block", parent=key_prefix))
|
||||
keys += __map_unet_up_block(LoraConversionKeySet("output_blocks", "up_blocks", parent=key_prefix))
|
||||
|
||||
keys += [LoraConversionKeySet("out.0", "conv_norm_out", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("out.2", "conv_out", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def convert_sdxl_lora_key_sets() -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
|
||||
keys += __map_unet(LoraConversionKeySet("unet", "lora_unet"))
|
||||
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
|
||||
keys += map_clip(LoraConversionKeySet("clip_g", "lora_te2"))
|
||||
|
||||
return keys
|
||||
19
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_t5.py
vendored
Normal file
19
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_t5.py
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
|
||||
|
||||
def map_t5(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
for k in map_prefix_range("encoder.block", "encoder.block", parent=key_prefix):
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.k", "layer.0.SelfAttention.k", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.o", "layer.0.SelfAttention.o", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.q", "layer.0.SelfAttention.q", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.v", "layer.0.SelfAttention.v", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_0", "layer.1.DenseReluDense.wi_0", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_1", "layer.1.DenseReluDense.wi_1", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wo", "layer.1.DenseReluDense.wo", parent=k)]
|
||||
|
||||
return keys
|
||||
0
invokeai/backend/model_manager/omi/vendor/model_spec/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/model_spec/__init__.py
vendored
Normal file
31
invokeai/backend/model_manager/omi/vendor/model_spec/architecture.py
vendored
Normal file
31
invokeai/backend/model_manager/omi/vendor/model_spec/architecture.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
stable_diffusion_1_lora = "stable-diffusion-v1/lora"
|
||||
stable_diffusion_1_inpainting_lora = "stable-diffusion-v1-inpainting/lora"
|
||||
|
||||
stable_diffusion_2_512_lora = "stable-diffusion-v2-512/lora"
|
||||
stable_diffusion_2_768_v_lora = "stable-diffusion-v2-768-v/lora"
|
||||
stable_diffusion_2_depth_lora = "stable-diffusion-v2-depth/lora"
|
||||
stable_diffusion_2_inpainting_lora = "stable-diffusion-v2-inpainting/lora"
|
||||
|
||||
stable_diffusion_3_medium_lora = "stable-diffusion-v3-medium/lora"
|
||||
stable_diffusion_35_medium_lora = "stable-diffusion-v3.5-medium/lora"
|
||||
stable_diffusion_35_large_lora = "stable-diffusion-v3.5-large/lora"
|
||||
|
||||
stable_diffusion_xl_1_lora = "stable-diffusion-xl-v1-base/lora"
|
||||
stable_diffusion_xl_1_inpainting_lora = "stable-diffusion-xl-v1-base-inpainting/lora"
|
||||
|
||||
wuerstchen_2_lora = "wuerstchen-v2-prior/lora"
|
||||
stable_cascade_1_stage_a_lora = "stable-cascade-v1-stage-a/lora"
|
||||
stable_cascade_1_stage_b_lora = "stable-cascade-v1-stage-b/lora"
|
||||
stable_cascade_1_stage_c_lora = "stable-cascade-v1-stage-c/lora"
|
||||
|
||||
pixart_alpha_lora = "pixart-alpha/lora"
|
||||
pixart_sigma_lora = "pixart-sigma/lora"
|
||||
|
||||
flux_dev_1_lora = "Flux.1-dev/lora"
|
||||
flux_fill_dev_1_lora = "Flux.1-fill-dev/lora"
|
||||
|
||||
sana_lora = "sana/lora"
|
||||
|
||||
hunyuan_video_lora = "hunyuan-video/lora"
|
||||
|
||||
hi_dream_i1_lora = "hidream-i1/lora"
|
||||
@@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies):
|
||||
dependencies: Optional[list[StarterModelWithoutDependencies]] = None
|
||||
|
||||
|
||||
class StarterModelBundles(BaseModel):
|
||||
class StarterModelBundle(BaseModel):
|
||||
name: str
|
||||
models: list[StarterModel]
|
||||
|
||||
@@ -109,7 +109,7 @@ flux_vae = StarterModel(
|
||||
|
||||
# region: Main
|
||||
flux_schnell_quantized = StarterModel(
|
||||
name="FLUX Schnell (Quantized)",
|
||||
name="FLUX.1 schnell (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
|
||||
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
@@ -117,7 +117,7 @@ flux_schnell_quantized = StarterModel(
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_dev_quantized = StarterModel(
|
||||
name="FLUX Dev (Quantized)",
|
||||
name="FLUX.1 dev (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
|
||||
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
@@ -125,7 +125,7 @@ flux_dev_quantized = StarterModel(
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_schnell = StarterModel(
|
||||
name="FLUX Schnell",
|
||||
name="FLUX.1 schnell",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
|
||||
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
@@ -133,13 +133,29 @@ flux_schnell = StarterModel(
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_dev = StarterModel(
|
||||
name="FLUX Dev",
|
||||
name="FLUX.1 dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
|
||||
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_kontext = StarterModel(
|
||||
name="FLUX.1 Kontext dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
|
||||
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_kontext_quantized = StarterModel(
|
||||
name="FLUX.1 Kontext dev (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
@@ -656,6 +672,7 @@ flux_fill = StarterModel(
|
||||
# List of starter models, displayed on the frontend.
|
||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||
STARTER_MODELS: list[StarterModel] = [
|
||||
flux_kontext_quantized,
|
||||
flux_schnell_quantized,
|
||||
flux_dev_quantized,
|
||||
flux_schnell,
|
||||
@@ -776,12 +793,13 @@ flux_bundle: list[StarterModel] = [
|
||||
flux_depth_control_lora,
|
||||
flux_redux,
|
||||
flux_fill,
|
||||
flux_kontext_quantized,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
|
||||
BaseModelType.StableDiffusion1: sd1_bundle,
|
||||
BaseModelType.StableDiffusionXL: sdxl_bundle,
|
||||
BaseModelType.Flux: flux_bundle,
|
||||
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
|
||||
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
|
||||
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
|
||||
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
|
||||
}
|
||||
|
||||
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
|
||||
|
||||
@@ -29,6 +29,7 @@ class BaseModelType(str, Enum):
|
||||
Imagen3 = "imagen3"
|
||||
Imagen4 = "imagen4"
|
||||
ChatGPT4o = "chatgpt-4o"
|
||||
FluxKontext = "flux-kontext"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
@@ -88,6 +89,7 @@ class ModelVariantType(str, Enum):
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
OMI = "omi"
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
LyCORIS = "lycoris"
|
||||
|
||||
@@ -9,13 +9,25 @@ module.exports = {
|
||||
// https://github.com/qdanik/eslint-plugin-path
|
||||
'path/no-relative-imports': ['error', { maxDepth: 0 }],
|
||||
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
|
||||
'i18next/no-literal-string': 'error',
|
||||
// TODO: ENABLE THIS RULE BEFORE v6.0.0
|
||||
// 'i18next/no-literal-string': 'error',
|
||||
// https://eslint.org/docs/latest/rules/no-console
|
||||
'no-console': 'error',
|
||||
'no-console': 'warn',
|
||||
// https://eslint.org/docs/latest/rules/no-promise-executor-return
|
||||
'no-promise-executor-return': 'error',
|
||||
// https://eslint.org/docs/latest/rules/require-await
|
||||
'require-await': 'error',
|
||||
// Restrict setActiveTab calls to only use-navigation-api.tsx
|
||||
'no-restricted-syntax': [
|
||||
'error',
|
||||
{
|
||||
selector: 'CallExpression[callee.name="setActiveTab"]',
|
||||
message:
|
||||
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
|
||||
},
|
||||
],
|
||||
// TODO: ENABLE THIS RULE BEFORE v6.0.0
|
||||
'react/display-name': 'off',
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
@@ -30,8 +42,38 @@ module.exports = {
|
||||
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
|
||||
},
|
||||
],
|
||||
'no-restricted-imports': [
|
||||
'error',
|
||||
{
|
||||
paths: [
|
||||
{
|
||||
name: 'lodash-es',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'lodash-es',
|
||||
message: 'Please use es-toolkit instead.',
|
||||
},
|
||||
{
|
||||
name: 'es-toolkit',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
* Allow setActiveTab calls only in use-navigation-api.tsx
|
||||
*/
|
||||
{
|
||||
files: ['**/use-navigation-api.tsx'],
|
||||
rules: {
|
||||
'no-restricted-syntax': 'off',
|
||||
},
|
||||
},
|
||||
/**
|
||||
* Overrides for stories
|
||||
*/
|
||||
|
||||
@@ -12,10 +12,8 @@ const config: KnipConfig = {
|
||||
'src/features/parameters/types/parameterSchemas.ts',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// TODO(psyche): restore HRF functionality?
|
||||
'src/features/hrf/**',
|
||||
// This feature is (temprarily?) disabled
|
||||
'src/features/controlLayers/components/InpaintMask/InpaintMaskAddButtons.tsx',
|
||||
// Will be using this
|
||||
'src/common/hooks/useAsyncState.ts',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
paths: {
|
||||
|
||||
@@ -38,70 +38,60 @@
|
||||
"test:ui": "vitest --coverage --ui",
|
||||
"test:no-watch": "vitest --no-watch"
|
||||
},
|
||||
"madge": {
|
||||
"excludeRegExp": [
|
||||
"^index.ts$"
|
||||
],
|
||||
"detectiveOptions": {
|
||||
"ts": {
|
||||
"skipTypeImports": true
|
||||
},
|
||||
"tsx": {
|
||||
"skipTypeImports": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"@atlaskit/pragmatic-drag-and-drop": "^1.5.3",
|
||||
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.0",
|
||||
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
|
||||
"@dagrejs/dagre": "^1.1.4",
|
||||
"@atlaskit/pragmatic-drag-and-drop": "^1.7.4",
|
||||
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.1",
|
||||
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.1.0",
|
||||
"@dagrejs/dagre": "^1.1.5",
|
||||
"@dagrejs/graphlib": "^2.2.4",
|
||||
"@fontsource-variable/inter": "^5.2.5",
|
||||
"@fontsource-variable/inter": "^5.2.6",
|
||||
"@invoke-ai/ui-library": "^0.0.46",
|
||||
"@nanostores/react": "^1.0.0",
|
||||
"@reduxjs/toolkit": "2.7.0",
|
||||
"@observ33r/object-equals": "^1.1.4",
|
||||
"@reduxjs/toolkit": "2.8.2",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
"@xyflow/react": "^12.6.0",
|
||||
"@xyflow/react": "^12.7.1",
|
||||
"ag-psd": "^28.2.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
"chakra-react-select": "^4.9.2",
|
||||
"cmdk": "^1.1.1",
|
||||
"compare-versions": "^6.1.1",
|
||||
"dockview": "^4.4.0",
|
||||
"es-toolkit": "^1.39.5",
|
||||
"filesize": "^10.1.6",
|
||||
"fracturedjsonjs": "^4.1.0",
|
||||
"framer-motion": "^11.10.0",
|
||||
"i18next": "^25.0.1",
|
||||
"i18next": "^25.2.1",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"idb-keyval": "^6.2.1",
|
||||
"idb-keyval": "^6.2.2",
|
||||
"jsondiffpatch": "^0.7.3",
|
||||
"konva": "^9.3.20",
|
||||
"linkify-react": "^4.2.0",
|
||||
"linkifyjs": "^4.2.0",
|
||||
"lodash-es": "^4.17.21",
|
||||
"linkify-react": "^4.3.1",
|
||||
"linkifyjs": "^4.3.1",
|
||||
"lru-cache": "^11.1.0",
|
||||
"mtwist": "^1.0.2",
|
||||
"nanoid": "^5.1.5",
|
||||
"nanostores": "^1.0.1",
|
||||
"new-github-issue-url": "^1.1.0",
|
||||
"overlayscrollbars": "^2.11.1",
|
||||
"overlayscrollbars": "^2.11.4",
|
||||
"overlayscrollbars-react": "^0.5.6",
|
||||
"perfect-freehand": "^1.2.2",
|
||||
"query-string": "^9.1.1",
|
||||
"query-string": "^9.2.1",
|
||||
"raf-throttle": "^2.0.6",
|
||||
"react": "^18.3.1",
|
||||
"react-colorful": "^5.6.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-error-boundary": "^5.0.0",
|
||||
"react-hook-form": "^7.56.1",
|
||||
"react-hook-form": "^7.58.1",
|
||||
"react-hotkeys-hook": "4.5.0",
|
||||
"react-i18next": "^15.5.1",
|
||||
"react-i18next": "^15.5.3",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-redux": "9.2.0",
|
||||
"react-resizable-panels": "^2.1.8",
|
||||
"react-resizable-panels": "^3.0.3",
|
||||
"react-textarea-autosize": "^8.5.9",
|
||||
"react-use": "^17.6.0",
|
||||
"react-virtuoso": "^4.12.6",
|
||||
"react-virtuoso": "^4.13.0",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-remember": "^5.2.0",
|
||||
"redux-undo": "^1.1.0",
|
||||
@@ -109,12 +99,12 @@
|
||||
"roarr": "^7.21.1",
|
||||
"serialize-error": "^12.0.0",
|
||||
"socket.io-client": "^4.8.1",
|
||||
"stable-hash": "^0.0.5",
|
||||
"use-debounce": "^10.0.4",
|
||||
"stable-hash": "^0.0.6",
|
||||
"use-debounce": "^10.0.5",
|
||||
"use-device-pixel-ratio": "^1.1.2",
|
||||
"uuid": "^11.1.0",
|
||||
"zod": "^3.24.3",
|
||||
"zod-validation-error": "^3.4.0"
|
||||
"zod": "^3.25.67",
|
||||
"zod-validation-error": "^3.5.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^18.2.0",
|
||||
@@ -131,7 +121,6 @@
|
||||
"@storybook/react": "^8.6.12",
|
||||
"@storybook/react-vite": "^8.6.12",
|
||||
"@storybook/theming": "^8.6.12",
|
||||
"@types/lodash-es": "^4.17.12",
|
||||
"@types/node": "^22.15.1",
|
||||
"@types/react": "^18.3.11",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
@@ -145,7 +134,7 @@
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-plugin-i18next": "^6.1.1",
|
||||
"eslint-plugin-path": "^1.3.0",
|
||||
"knip": "^5.50.5",
|
||||
"knip": "^5.61.3",
|
||||
"openapi-types": "^12.1.3",
|
||||
"openapi-typescript": "^7.6.1",
|
||||
"prettier": "^3.5.3",
|
||||
@@ -154,7 +143,7 @@
|
||||
"tsafe": "^1.8.5",
|
||||
"type-fest": "^4.40.0",
|
||||
"typescript": "^5.8.3",
|
||||
"vite": "^6.3.3",
|
||||
"vite": "^7.0.2",
|
||||
"vite-plugin-css-injected-by-js": "^3.5.2",
|
||||
"vite-plugin-dts": "^4.5.3",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
@@ -162,7 +151,7 @@
|
||||
"vitest": "^3.1.2"
|
||||
},
|
||||
"engines": {
|
||||
"pnpm": "8"
|
||||
"pnpm": "10"
|
||||
},
|
||||
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
|
||||
"packageManager": "pnpm@10.12.4"
|
||||
}
|
||||
|
||||
12920
invokeai/frontend/web/pnpm-lock.yaml
generated
12920
invokeai/frontend/web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
3
invokeai/frontend/web/pnpm-workspace.yaml
Normal file
3
invokeai/frontend/web/pnpm-workspace.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
onlyBuiltDependencies:
|
||||
- '@swc/core'
|
||||
- esbuild
|
||||
@@ -225,7 +225,16 @@
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Add Prompt Trigger",
|
||||
"compatibleEmbeddings": "Compatible Embeddings",
|
||||
"noMatchingTriggers": "No matching triggers"
|
||||
"noMatchingTriggers": "No matching triggers",
|
||||
"generateFromImage": "Generate prompt from image",
|
||||
"expandCurrentPrompt": "Expand Current Prompt",
|
||||
"uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
|
||||
"expandingPrompt": "Expanding prompt...",
|
||||
"resultTitle": "Prompt Expansion Complete",
|
||||
"resultSubtitle": "Choose how to handle the expanded prompt:",
|
||||
"replace": "Replace",
|
||||
"insert": "Insert",
|
||||
"discard": "Discard"
|
||||
},
|
||||
"queue": {
|
||||
"queue": "Queue",
|
||||
@@ -335,14 +344,14 @@
|
||||
"images": "Images",
|
||||
"assets": "Assets",
|
||||
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
|
||||
"assetsTab": "Files you’ve uploaded for use in your projects.",
|
||||
"assetsTab": "Files you've uploaded for use in your projects.",
|
||||
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
|
||||
"autoSwitchNewImages": "Auto-Switch to New Images",
|
||||
"boardsSettings": "Boards Settings",
|
||||
"copy": "Copy",
|
||||
"currentlyInUse": "This image is currently in use in the following features:",
|
||||
"drop": "Drop",
|
||||
"dropOrUpload": "$t(gallery.drop) or Upload",
|
||||
"dropOrUpload": "Drop or Upload",
|
||||
"dropToUpload": "$t(gallery.drop) to Upload",
|
||||
"deleteImage_one": "Delete Image",
|
||||
"deleteImage_other": "Delete {{count}} Images",
|
||||
@@ -357,7 +366,7 @@
|
||||
"gallerySettings": "Gallery Settings",
|
||||
"go": "Go",
|
||||
"image": "image",
|
||||
"imagesTab": "Images you’ve created and saved within Invoke.",
|
||||
"imagesTab": "Images you've created and saved within Invoke.",
|
||||
"imagesSettings": "Gallery Images Settings",
|
||||
"jump": "Jump",
|
||||
"loading": "Loading",
|
||||
@@ -396,7 +405,8 @@
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
|
||||
"openViewer": "Open Viewer",
|
||||
"closeViewer": "Close Viewer",
|
||||
"move": "Move"
|
||||
"move": "Move",
|
||||
"useForPromptGeneration": "Use for Prompt Generation"
|
||||
},
|
||||
"hotkeys": {
|
||||
"hotkeys": "Hotkeys",
|
||||
@@ -579,6 +589,16 @@
|
||||
"cancelTransform": {
|
||||
"title": "Cancel Transform",
|
||||
"desc": "Cancel the pending transform."
|
||||
},
|
||||
"settings": {
|
||||
"behavior": "Behavior",
|
||||
"display": "Display",
|
||||
"grid": "Grid",
|
||||
"debug": "Debug"
|
||||
},
|
||||
"toggleNonRasterLayers": {
|
||||
"title": "Toggle Non-Raster Layers",
|
||||
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
|
||||
}
|
||||
},
|
||||
"workflows": {
|
||||
@@ -742,7 +762,7 @@
|
||||
"vae": "VAE",
|
||||
"width": "Width",
|
||||
"workflow": "Workflow",
|
||||
"canvasV2Metadata": "Canvas"
|
||||
"canvasV2Metadata": "Canvas Layers"
|
||||
},
|
||||
"modelManager": {
|
||||
"active": "active",
|
||||
@@ -763,7 +783,7 @@
|
||||
"convertToDiffusers": "Convert To Diffusers",
|
||||
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
|
||||
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
|
||||
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
|
||||
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in the InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
|
||||
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
|
||||
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
|
||||
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
|
||||
@@ -806,7 +826,11 @@
|
||||
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
|
||||
"urlUnauthorizedErrorMessage2": "Learn how here.",
|
||||
"imageEncoderModelId": "Image Encoder Model ID",
|
||||
"includesNModels": "Includes {{n}} models and their dependencies",
|
||||
"installedModelsCount": "{{installed}} of {{total}} models installed.",
|
||||
"includesNModels": "Includes {{n}} models and their dependencies.",
|
||||
"allNModelsInstalled": "All {{count}} models installed",
|
||||
"nToInstall": "{{count}} to install",
|
||||
"nAlreadyInstalled": "{{count}} already installed",
|
||||
"installQueue": "Install Queue",
|
||||
"inplaceInstall": "In-place install",
|
||||
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
|
||||
@@ -869,6 +893,25 @@
|
||||
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"bundleAlreadyInstalled": "Bundle already installed",
|
||||
"bundleAlreadyInstalledDesc": "All models in the {{bundleName}} bundle are already installed.",
|
||||
"launchpadTab": "Launchpad",
|
||||
"launchpad": {
|
||||
"welcome": "Welcome to Model Management",
|
||||
"description": "Invoke requires models to be installed to utilize most features of the platform. Choose from manual installation options or explore curated starter models.",
|
||||
"manualInstall": "Manual Installation",
|
||||
"urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.",
|
||||
"huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.",
|
||||
"scanFolderDescription": "Scan a local folder to automatically detect and install models.",
|
||||
"recommendedModels": "Recommended Models",
|
||||
"exploreStarter": "Or browse all available starter models",
|
||||
"quickStart": "Quick Start Bundles",
|
||||
"bundleDescription": "Each bundle includes essential models for each model family and curated base models to get started.",
|
||||
"browseAll": "Or browse all available models:",
|
||||
"stableDiffusion15": "Stable Diffusion 1.5",
|
||||
"sdxl": "SDXL",
|
||||
"fluxDev": "FLUX.1 dev"
|
||||
},
|
||||
"controlLora": "Control LoRA",
|
||||
"llavaOnevision": "LLaVA OneVision",
|
||||
"syncModels": "Sync Models",
|
||||
@@ -905,7 +948,8 @@
|
||||
"selectModel": "Select a Model",
|
||||
"noLoRAsInstalled": "No LoRAs installed",
|
||||
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
|
||||
"defaultVAE": "Default VAE"
|
||||
"defaultVAE": "Default VAE",
|
||||
"noCompatibleLoRAs": "No Compatible LoRAs"
|
||||
},
|
||||
"nodes": {
|
||||
"arithmeticSequence": "Arithmetic Sequence",
|
||||
@@ -1147,6 +1191,7 @@
|
||||
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
|
||||
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
|
||||
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
|
||||
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
|
||||
"canvasIsFiltering": "Canvas is busy (filtering)",
|
||||
"canvasIsTransforming": "Canvas is busy (transforming)",
|
||||
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
|
||||
@@ -1154,7 +1199,9 @@
|
||||
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"systemDisconnected": "System disconnected"
|
||||
"systemDisconnected": "System disconnected",
|
||||
"promptExpansionPending": "Prompt expansion in progress",
|
||||
"promptExpansionResultPending": "Please accept or discard your prompt expansion result"
|
||||
},
|
||||
"maskBlur": "Mask Blur",
|
||||
"negativePromptPlaceholder": "Negative Prompt",
|
||||
@@ -1312,6 +1359,21 @@
|
||||
"problemCopyingLayer": "Unable to Copy Layer",
|
||||
"problemSavingLayer": "Unable to Save Layer",
|
||||
"problemDownloadingImage": "Unable to Download Image",
|
||||
"noRasterLayers": "No Raster Layers Found",
|
||||
"noRasterLayersDesc": "Create at least one raster layer to export to PSD",
|
||||
"noActiveRasterLayers": "No Active Raster Layers",
|
||||
"noActiveRasterLayersDesc": "Enable at least one raster layer to export to PSD",
|
||||
"noVisibleRasterLayers": "No Visible Raster Layers",
|
||||
"noVisibleRasterLayersDesc": "Enable at least one raster layer to export to PSD",
|
||||
"invalidCanvasDimensions": "Invalid Canvas Dimensions",
|
||||
"canvasTooLarge": "Canvas Too Large",
|
||||
"canvasTooLargeDesc": "Canvas dimensions exceed the maximum allowed size for PSD export. Reduce the total width and height of the canvas of the canvas and try again.",
|
||||
"failedToProcessLayers": "Failed to Process Layers",
|
||||
"psdExportSuccess": "PSD Export Complete",
|
||||
"psdExportSuccessDesc": "Successfully exported {{count}} layers to PSD file",
|
||||
"problemExportingPSD": "Problem Exporting PSD",
|
||||
"canvasManagerNotAvailable": "Canvas Manager Not Available",
|
||||
"noValidLayerAdapters": "No Valid Layer Adapters Found",
|
||||
"pasteSuccess": "Pasted to {{destination}}",
|
||||
"pasteFailed": "Paste Failed",
|
||||
"prunedQueue": "Pruned Queue",
|
||||
@@ -1337,9 +1399,15 @@
|
||||
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
|
||||
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
|
||||
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
|
||||
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
|
||||
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
|
||||
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
|
||||
"workflowUnpublished": "Workflow Unpublished"
|
||||
"workflowUnpublished": "Workflow Unpublished",
|
||||
"sentToCanvas": "Sent to Canvas",
|
||||
"sentToUpscale": "Sent to Upscale",
|
||||
"promptGenerationStarted": "Prompt generation started",
|
||||
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
|
||||
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
|
||||
},
|
||||
"popovers": {
|
||||
"clipSkip": {
|
||||
@@ -1862,6 +1930,7 @@
|
||||
"saveCanvasToGallery": "Save Canvas to Gallery",
|
||||
"saveBboxToGallery": "Save Bbox to Gallery",
|
||||
"saveLayerToAssets": "Save Layer to Assets",
|
||||
"exportCanvasToPSD": "Export Canvas to PSD",
|
||||
"cropLayerToBbox": "Crop Layer to Bbox",
|
||||
"savedToGalleryOk": "Saved to Gallery",
|
||||
"savedToGalleryError": "Error saving to gallery",
|
||||
@@ -1887,11 +1956,13 @@
|
||||
"mergingLayers": "Merging layers",
|
||||
"clearHistory": "Clear History",
|
||||
"bboxOverlay": "Show Bbox Overlay",
|
||||
"ruleOfThirds": "Show Rule of Thirds",
|
||||
"newSession": "New Session",
|
||||
"clearCaches": "Clear Caches",
|
||||
"recalculateRects": "Recalculate Rects",
|
||||
"clipToBbox": "Clip Strokes to Bbox",
|
||||
"outputOnlyMaskedRegions": "Output Only Generated Regions",
|
||||
"saveAllImagesToGallery": "Save All Images to Gallery",
|
||||
"addLayer": "Add Layer",
|
||||
"duplicate": "Duplicate",
|
||||
"moveToFront": "Move to Front",
|
||||
@@ -1992,6 +2063,8 @@
|
||||
"disableTransparencyEffect": "Disable Transparency Effect",
|
||||
"hidingType": "Hiding {{type}}",
|
||||
"showingType": "Showing {{type}}",
|
||||
"showNonRasterLayers": "Show Non-Raster Layers (Shift+H)",
|
||||
"hideNonRasterLayers": "Hide Non-Raster Layers (Shift+H)",
|
||||
"dynamicGrid": "Dynamic Grid",
|
||||
"logDebugInfo": "Log Debug Info",
|
||||
"locked": "Locked",
|
||||
@@ -2015,7 +2088,9 @@
|
||||
"resetGenerationSettings": "Reset Generation Settings",
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or <PullBboxButton>pull the bounding box into this layer</PullBboxButton> to get started.",
|
||||
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
|
||||
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
|
||||
"imageNoise": "Image Noise",
|
||||
"denoiseLimit": "Denoise Limit",
|
||||
"warnings": {
|
||||
@@ -2256,6 +2331,9 @@
|
||||
"label": "Preserve Masked Region",
|
||||
"alert": "Preserving Masked Region"
|
||||
},
|
||||
"saveAllImagesToGallery": {
|
||||
"alert": "Saving All Images to Gallery"
|
||||
},
|
||||
"isolatedStagingPreview": "Isolated Staging Preview",
|
||||
"isolatedPreview": "Isolated Preview",
|
||||
"isolatedLayerPreview": "Isolated Layer Preview",
|
||||
@@ -2284,6 +2362,7 @@
|
||||
"newGlobalReferenceImage": "New Global Reference Image",
|
||||
"newRegionalReferenceImage": "New Regional Reference Image",
|
||||
"newControlLayer": "New Control Layer",
|
||||
"newResizedControlLayer": "New Resized Control Layer",
|
||||
"newRasterLayer": "New Raster Layer",
|
||||
"newInpaintMask": "New Inpaint Mask",
|
||||
"newRegionalGuidance": "New Regional Guidance",
|
||||
@@ -2301,6 +2380,11 @@
|
||||
"saveToGallery": "Save To Gallery",
|
||||
"showResultsOn": "Showing Results",
|
||||
"showResultsOff": "Hiding Results"
|
||||
},
|
||||
"autoSwitch": {
|
||||
"off": "Off",
|
||||
"switchOnStart": "On Start",
|
||||
"switchOnFinish": "On Finish"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
@@ -2367,7 +2451,8 @@
|
||||
"uploadImage": "Upload Image",
|
||||
"useForTemplate": "Use For Prompt Template",
|
||||
"viewList": "View Template List",
|
||||
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
|
||||
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
|
||||
"togglePromptPreviews": "Toggle Prompt Previews"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invite Teammates",
|
||||
@@ -2387,6 +2472,55 @@
|
||||
"upscaling": "Upscaling",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
|
||||
"gallery": "Gallery"
|
||||
},
|
||||
"launchpad": {
|
||||
"workflowsTitle": "Go deep with Workflows.",
|
||||
"upscalingTitle": "Upscale and add detail.",
|
||||
"canvasTitle": "Edit and refine on Canvas.",
|
||||
"generateTitle": "Generate images from text prompts.",
|
||||
"modelGuideText": "Want to learn what prompts work best for each model?",
|
||||
"modelGuideLink": "Check out our Model Guide.",
|
||||
"workflows": {
|
||||
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
|
||||
"learnMoreLink": "Learn more about creating workflows",
|
||||
"browseTemplates": {
|
||||
"title": "Browse Workflow Templates",
|
||||
"description": "Choose from pre-built workflows for common tasks"
|
||||
},
|
||||
"createNew": {
|
||||
"title": "Create a new Workflow",
|
||||
"description": "Start a new workflow from scratch"
|
||||
},
|
||||
"loadFromFile": {
|
||||
"title": "Load workflow from file",
|
||||
"description": "Upload a workflow to start with an existing setup"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
"uploadImage": {
|
||||
"title": "Upload Image to Upscale",
|
||||
"description": "Click or drag an image to upscale (JPG, PNG, WebP up to 100MB)"
|
||||
},
|
||||
"replaceImage": {
|
||||
"title": "Replace Current Image",
|
||||
"description": "Click or drag a new image to replace the current one"
|
||||
},
|
||||
"imageReady": {
|
||||
"title": "Image Ready",
|
||||
"description": "Press Invoke to begin upscaling"
|
||||
},
|
||||
"readyToUpscale": {
|
||||
"title": "Ready to upscale!",
|
||||
"description": "Configure your settings below, then click the Invoke button to begin upscaling your image."
|
||||
},
|
||||
"upscaleModel": "Upscale Model",
|
||||
"model": "Model",
|
||||
"scale": "Scale",
|
||||
"helpText": {
|
||||
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
|
||||
"styleAdvice": "Upscaling works best with the general style of your image."
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"system": {
|
||||
@@ -2426,8 +2560,9 @@
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"items": [
|
||||
"Inpainting: Per-mask noise levels and denoise limits.",
|
||||
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
|
||||
"Generate images faster with new Launchpads and a simplified Generate tab.",
|
||||
"Edit with prompts using Flux Kontext Dev.",
|
||||
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
@@ -2436,62 +2571,16 @@
|
||||
"supportVideos": {
|
||||
"supportVideos": "Support Videos",
|
||||
"gettingStarted": "Getting Started",
|
||||
"controlCanvas": "Control Canvas",
|
||||
"watch": "Watch",
|
||||
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
|
||||
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
|
||||
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
|
||||
"videos": {
|
||||
"creatingYourFirstImage": {
|
||||
"title": "Creating Your First Image",
|
||||
"description": "Introduction to creating an image from scratch using Invoke's tools."
|
||||
"gettingStarted": {
|
||||
"title": "Getting Started with Invoke",
|
||||
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
|
||||
},
|
||||
"usingControlLayersAndReferenceGuides": {
|
||||
"title": "Using Control Layers and Reference Guides",
|
||||
"description": "Learn how to guide your image creation with control layers and reference images."
|
||||
},
|
||||
"understandingImageToImageAndDenoising": {
|
||||
"title": "Understanding Image-to-Image and Denoising",
|
||||
"description": "Overview of image-to-image transformations and denoising in Invoke."
|
||||
},
|
||||
"exploringAIModelsAndConceptAdapters": {
|
||||
"title": "Exploring AI Models and Concept Adapters",
|
||||
"description": "Dive into AI models and how to use concept adapters for creative control."
|
||||
},
|
||||
"creatingAndComposingOnInvokesControlCanvas": {
|
||||
"title": "Creating and Composing on Invoke's Control Canvas",
|
||||
"description": "Learn to compose images using Invoke's control canvas."
|
||||
},
|
||||
"upscaling": {
|
||||
"title": "Upscaling",
|
||||
"description": "How to upscale images with Invoke's tools to enhance resolution."
|
||||
},
|
||||
"howDoIGenerateAndSaveToTheGallery": {
|
||||
"title": "How Do I Generate and Save to the Gallery?",
|
||||
"description": "Steps to generate and save images to the gallery."
|
||||
},
|
||||
"howDoIEditOnTheCanvas": {
|
||||
"title": "How Do I Edit on the Canvas?",
|
||||
"description": "Guide to editing images directly on the canvas."
|
||||
},
|
||||
"howDoIDoImageToImageTransformation": {
|
||||
"title": "How Do I Do Image-to-Image Transformation?",
|
||||
"description": "Tutorial on performing image-to-image transformations in Invoke."
|
||||
},
|
||||
"howDoIUseControlNetsAndControlLayers": {
|
||||
"title": "How Do I Use Control Nets and Control Layers?",
|
||||
"description": "Learn to apply control layers and controlnets to your images."
|
||||
},
|
||||
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
|
||||
"title": "How Do I Use Global IP Adapters and Reference Images?",
|
||||
"description": "Introduction to adding reference images and global IP adapters."
|
||||
},
|
||||
"howDoIUseInpaintMasks": {
|
||||
"title": "How Do I Use Inpaint Masks?",
|
||||
"description": "How to apply inpaint masks for image correction and variation."
|
||||
},
|
||||
"howDoIOutpaint": {
|
||||
"title": "How Do I Outpaint?",
|
||||
"description": "Guide to outpainting beyond the original image borders."
|
||||
"studioSessions": {
|
||||
"title": "Studio Sessions",
|
||||
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,7 @@ import { Box } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
|
||||
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
|
||||
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
|
||||
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||
@@ -12,6 +11,7 @@ import { memo, useCallback } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import ThemeLocaleProvider from './ThemeLocaleProvider';
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props {
|
||||
@@ -31,12 +31,14 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
|
||||
return (
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
<ThemeLocaleProvider>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
</ThemeLocaleProvider>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
|
||||
import { setupListeners } from '@reduxjs/toolkit/query';
|
||||
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
|
||||
@@ -8,19 +9,24 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import type { PartialAppConfig } from 'app/types/invokeai';
|
||||
import { useFocusRegionWatcher } from 'common/hooks/focus';
|
||||
import { useCloseChakraTooltipsOnDragFix } from 'common/hooks/useCloseChakraTooltipsOnDragFix';
|
||||
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
|
||||
import { useDndMonitor } from 'features/dnd/useDndMonitor';
|
||||
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
|
||||
import { useReadinessWatcher } from 'features/queue/store/readiness';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { selectLanguage } from 'features/system/store/systemSelectors';
|
||||
import { useNavigationApi } from 'features/ui/layouts/use-navigation-api';
|
||||
import i18n from 'i18n';
|
||||
import { size } from 'lodash-es';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
import { useGetQueueCountsByDestinationQuery } from 'services/api/endpoints/queue';
|
||||
import { useSocketIO } from 'services/events/useSocketIO';
|
||||
|
||||
const queueCountArg = { destination: 'canvas' };
|
||||
|
||||
/**
|
||||
* GlobalHookIsolator is a logical component that runs global hooks in an isolated component, so that they do not
|
||||
* cause needless re-renders of any other components.
|
||||
@@ -38,22 +44,31 @@ export const GlobalHookIsolator = memo(
|
||||
useGlobalHotkeys();
|
||||
useGetOpenAPISchemaQuery();
|
||||
useSyncLoggingConfig();
|
||||
useCloseChakraTooltipsOnDragFix();
|
||||
useNavigationApi();
|
||||
useDndMonitor();
|
||||
|
||||
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
|
||||
// and/or in progress canvas sessions.
|
||||
useGetQueueCountsByDestinationQuery(queueCountArg);
|
||||
|
||||
useEffect(() => {
|
||||
i18n.changeLanguage(language);
|
||||
}, [language]);
|
||||
|
||||
useEffect(() => {
|
||||
if (size(config)) {
|
||||
logger.info({ config }, 'Received config');
|
||||
dispatch(configChanged(config));
|
||||
}
|
||||
logger.info({ config }, 'Received config');
|
||||
dispatch(configChanged(config));
|
||||
}, [dispatch, config, logger]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(appStarted());
|
||||
}, [dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
return setupListeners(dispatch);
|
||||
}, [dispatch]);
|
||||
|
||||
useStudioInitAction(studioInitAction);
|
||||
useStarterModelsToast();
|
||||
useSyncQueueStatus();
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
|
||||
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
|
||||
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
|
||||
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
|
||||
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
|
||||
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const GlobalImageHotkeys = memo(() => {
|
||||
useAssertSingleton('GlobalImageHotkeys');
|
||||
const imageDTO = useAppSelector(selectLastSelectedImage);
|
||||
const imageName = useAppSelector(selectLastSelectedImage);
|
||||
const imageDTO = useImageDTO(imageName);
|
||||
|
||||
if (!imageDTO) {
|
||||
return null;
|
||||
@@ -25,59 +30,64 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
|
||||
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
|
||||
const isGalleryFocused = useIsRegionFocused('gallery');
|
||||
const isViewerFocused = useIsRegionFocused('viewer');
|
||||
const imageActions = useImageActions(imageDTO);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
|
||||
const isFocusOK = isGalleryFocused || isViewerFocused;
|
||||
|
||||
const recallAll = useRecallAll(imageDTO);
|
||||
const recallRemix = useRecallRemix(imageDTO);
|
||||
const recallPrompts = useRecallPrompts(imageDTO);
|
||||
const recallSeed = useRecallSeed(imageDTO);
|
||||
const recallDimensions = useRecallDimensions(imageDTO);
|
||||
const loadWorkflow = useLoadWorkflow(imageDTO);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'loadWorkflow',
|
||||
category: 'viewer',
|
||||
callback: imageActions.loadWorkflow,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
|
||||
callback: loadWorkflow.load,
|
||||
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
|
||||
dependencies: [loadWorkflow, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallAll',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallAll,
|
||||
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
|
||||
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
|
||||
callback: recallAll.recall,
|
||||
options: { enabled: recallAll.isEnabled && isFocusOK },
|
||||
dependencies: [recallAll, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallSeed',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallSeed,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
|
||||
callback: recallSeed.recall,
|
||||
options: { enabled: recallSeed.isEnabled && isFocusOK },
|
||||
dependencies: [recallSeed, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallPrompts',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallPrompts,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
|
||||
callback: recallPrompts.recall,
|
||||
options: { enabled: recallPrompts.isEnabled && isFocusOK },
|
||||
dependencies: [recallPrompts, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'remix',
|
||||
category: 'viewer',
|
||||
callback: imageActions.remix,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
|
||||
callback: recallRemix.recall,
|
||||
options: { enabled: recallRemix.isEnabled && isFocusOK },
|
||||
dependencies: [recallRemix, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'useSize',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallSize,
|
||||
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
|
||||
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
|
||||
});
|
||||
useRegisteredHotkeys({
|
||||
id: 'runPostprocessing',
|
||||
category: 'viewer',
|
||||
callback: imageActions.upscale,
|
||||
options: { enabled: isUpscalingEnabled && isViewerFocused },
|
||||
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
|
||||
callback: recallDimensions.recall,
|
||||
options: { enabled: recallDimensions.isEnabled && isFocusOK },
|
||||
dependencies: [recallDimensions, isFocusOK],
|
||||
});
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
NewGallerySessionDialog,
|
||||
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
|
||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
@@ -15,6 +15,7 @@ import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/workflow
|
||||
import { WorkflowLibraryModal } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryModal';
|
||||
import { CancelAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
|
||||
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { DeleteAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/DeleteAllExceptCurrentQueueItemConfirmationAlertDialog';
|
||||
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
|
||||
@@ -39,6 +40,7 @@ export const GlobalModalIsolator = memo(() => {
|
||||
<StylePresetModal />
|
||||
<WorkflowLibraryModal />
|
||||
<CancelAllExceptCurrentQueueItemConfirmationAlertDialog />
|
||||
<DeleteAllExceptCurrentQueueItemConfirmationAlertDialog />
|
||||
<ClearQueueConfirmationsAlertDialog />
|
||||
<NewWorkflowConfirmationAlertDialog />
|
||||
<LoadWorkflowConfirmationAlertDialog />
|
||||
|
||||
@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import '@fontsource-variable/inter';
|
||||
import 'overlayscrollbars/overlayscrollbars.css';
|
||||
import '@xyflow/react/dist/base.css';
|
||||
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
|
||||
|
||||
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import type { ReactNode } from 'react';
|
||||
|
||||
@@ -3,13 +3,12 @@ import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import { settingsSendToCanvasChanged } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { sentImageToCanvas } from 'features/gallery/store/actions';
|
||||
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
|
||||
import { MetadataUtils } from 'features/metadata/parsing';
|
||||
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
|
||||
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
|
||||
import {
|
||||
@@ -20,7 +19,9 @@ import {
|
||||
} from 'features/nodes/store/workflowLibrarySlice';
|
||||
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
@@ -91,12 +92,10 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
objects: [imageObject],
|
||||
};
|
||||
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
store.dispatch(canvasReset());
|
||||
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
|
||||
store.dispatch(settingsSendToCanvasChanged(true));
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
store.dispatch(sentImageToCanvas());
|
||||
$imageViewer.set(false);
|
||||
toast({
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'info',
|
||||
@@ -118,25 +117,25 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
return;
|
||||
}
|
||||
const metadata = getImageMetadataResult.value;
|
||||
store.dispatch(canvasReset());
|
||||
// This shows a toast
|
||||
await parseAndRecallAllMetadata(metadata, true);
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
await MetadataUtils.recallAll(metadata, store);
|
||||
},
|
||||
[store, t]
|
||||
);
|
||||
|
||||
const handleLoadWorkflow = useCallback(
|
||||
async (workflowId: string) => {
|
||||
(workflowId: string) => {
|
||||
// This shows a toast
|
||||
await loadWorkflowWithDialog({
|
||||
loadWorkflowWithDialog({
|
||||
type: 'library',
|
||||
data: workflowId,
|
||||
onSuccess: () => {
|
||||
store.dispatch(setActiveTab('workflows'));
|
||||
navigationApi.switchToTab('workflows');
|
||||
},
|
||||
});
|
||||
},
|
||||
[loadWorkflowWithDialog, store]
|
||||
[loadWorkflowWithDialog]
|
||||
);
|
||||
|
||||
const handleSelectStylePreset = useCallback(
|
||||
@@ -150,7 +149,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
return;
|
||||
}
|
||||
store.dispatch(activeStylePresetIdChanged(stylePresetId));
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
navigationApi.switchToTab('canvas');
|
||||
toast({
|
||||
title: t('toast.stylePresetLoaded'),
|
||||
status: 'info',
|
||||
@@ -160,37 +159,34 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
);
|
||||
|
||||
const handleGoToDestination = useCallback(
|
||||
(destination: StudioDestinationAction['data']['destination']) => {
|
||||
async (destination: StudioDestinationAction['data']['destination']) => {
|
||||
switch (destination) {
|
||||
case 'generation':
|
||||
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
// Go to the generate tab, open the launchpad
|
||||
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
|
||||
store.dispatch(paramsReset());
|
||||
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
store.dispatch(settingsSendToCanvasChanged(false));
|
||||
$imageViewer.set(true);
|
||||
break;
|
||||
case 'canvas':
|
||||
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
store.dispatch(settingsSendToCanvasChanged(true));
|
||||
$imageViewer.set(false);
|
||||
// Go to the canvas tab, open the launchpad
|
||||
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
break;
|
||||
case 'workflows':
|
||||
// Go to the workflows tab
|
||||
store.dispatch(setActiveTab('workflows'));
|
||||
navigationApi.switchToTab('workflows');
|
||||
break;
|
||||
case 'upscaling':
|
||||
// Go to the upscaling tab
|
||||
store.dispatch(setActiveTab('upscaling'));
|
||||
navigationApi.switchToTab('upscaling');
|
||||
break;
|
||||
case 'viewAllWorkflows':
|
||||
// Go to the workflows tab and open the workflow library modal
|
||||
store.dispatch(setActiveTab('workflows'));
|
||||
navigationApi.switchToTab('workflows');
|
||||
$isWorkflowLibraryModalOpen.set(true);
|
||||
break;
|
||||
case 'viewAllWorkflowsRecommended':
|
||||
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
|
||||
store.dispatch(setActiveTab('workflows'));
|
||||
navigationApi.switchToTab('workflows');
|
||||
$isWorkflowLibraryModalOpen.set(true);
|
||||
store.dispatch(workflowLibraryViewChanged('defaults'));
|
||||
store.dispatch(workflowLibraryTagsReset());
|
||||
@@ -202,7 +198,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
break;
|
||||
case 'viewAllStylePresets':
|
||||
// Go to the canvas tab and open the style presets menu
|
||||
store.dispatch(setActiveTab('canvas'));
|
||||
navigationApi.switchToTab('canvas');
|
||||
$isStylePresetsMenuOpen.set(true);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { createLogWriter } from '@roarr/browser-log-writer';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Logger, MessageSerializer } from 'roarr';
|
||||
import { ROARR, Roarr } from 'roarr';
|
||||
import { z } from 'zod';
|
||||
import { z } from 'zod/v4';
|
||||
|
||||
const serializeMessage: MessageSerializer = (message) => {
|
||||
return JSON.stringify(message);
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { objectEquals } from '@observ33r/object-equals';
|
||||
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
/**
|
||||
* A memoized selector creator that uses LRU cache and lodash's isEqual for equality check.
|
||||
* A memoized selector creator that uses LRU cache and @observ33r/object-equals's objectEquals for equality check.
|
||||
*/
|
||||
export const createMemoizedSelector = createSelectorCreator({
|
||||
memoize: lruMemoize,
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
resultEqualityCheck: objectEquals,
|
||||
},
|
||||
argsMemoize: lruMemoize,
|
||||
});
|
||||
|
||||
@@ -8,10 +8,13 @@ import { diff } from 'jsondiffpatch';
|
||||
* Super simple logger middleware. Useful for debugging when the redux devtools are awkward.
|
||||
*/
|
||||
export const getDebugLoggerMiddleware =
|
||||
(options?: { withDiff?: boolean; withNextState?: boolean }): Middleware =>
|
||||
(options?: { filter?: (action: unknown) => boolean; withDiff?: boolean; withNextState?: boolean }): Middleware =>
|
||||
(api: MiddlewareAPI) =>
|
||||
(next) =>
|
||||
(action) => {
|
||||
if (options?.filter?.(action)) {
|
||||
return next(action);
|
||||
}
|
||||
const originalState = api.getState();
|
||||
console.log('REDUX: dispatching', action);
|
||||
const result = next(action);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { TypedStartListening } from '@reduxjs/toolkit';
|
||||
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
|
||||
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
@@ -9,16 +8,9 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
|
||||
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
|
||||
import { addImageToDeleteSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected';
|
||||
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
@@ -27,7 +19,6 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
|
||||
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@@ -47,27 +38,12 @@ export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
|
||||
addImageUploadedFulfilledListener(startAppListening);
|
||||
|
||||
// Image deleted
|
||||
addImageDeletionListeners(startAppListening);
|
||||
addDeleteBoardAndImagesFulfilledListener(startAppListening);
|
||||
addImageToDeleteSelectedListener(startAppListening);
|
||||
|
||||
// Image starred
|
||||
addImagesStarredListener(startAppListening);
|
||||
addImagesUnstarredListener(startAppListening);
|
||||
|
||||
// Gallery
|
||||
addGalleryImageClickedListener(startAppListening);
|
||||
addGalleryOffsetChangedListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
// Canvas actions
|
||||
addStagingListeners(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
matcher: matchAnyBoardDeleted,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const deletedBoardId = action.meta.arg.originalArgs;
|
||||
const deletedBoardId = action.meta.arg.originalArgs.board_id;
|
||||
const { autoAddBoardId, selectedBoardId } = state.gallery;
|
||||
|
||||
// If the deleted board was currently selected, we should reset the selected board to uncategorized
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { canvasReset, newSessionRequested } from 'features/controlLayers/store/actions';
|
||||
import { stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
const matchCanvasOrStagingAreaReset = isAnyOf(stagingAreaReset, canvasReset, newSessionRequested);
|
||||
|
||||
export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: matchCanvasOrStagingAreaReset,
|
||||
effect: async (_, { dispatch }) => {
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.cancelByBatchDestination.initiate(
|
||||
{ destination: 'canvas' },
|
||||
{ fixedCacheKey: 'cancelByBatchOrigin' }
|
||||
)
|
||||
);
|
||||
const { canceled } = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
if (canceled > 0) {
|
||||
log.debug(`Canceled ${canceled} canvas batches`);
|
||||
toast({
|
||||
id: 'CANCEL_BATCH_SUCCEEDED',
|
||||
title: t('queue.cancelBatchSucceeded'),
|
||||
status: 'success',
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
log.error('Failed to cancel canvas batches');
|
||||
toast({
|
||||
id: 'CANCEL_BATCH_FAILED',
|
||||
title: t('queue.cancelBatchFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,15 +1,29 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const appStarted = createAction('app/appStarted');
|
||||
|
||||
export const addAppStartedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: appStarted,
|
||||
effect: (action, { unsubscribe, cancelActiveListeners }) => {
|
||||
effect: async (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
|
||||
// this should only run once
|
||||
cancelActiveListeners();
|
||||
unsubscribe();
|
||||
|
||||
// ensure an image is selected when we load the first board
|
||||
const firstImageLoad = await take(imagesApi.endpoints.getImageNames.matchFulfilled);
|
||||
if (firstImageLoad !== null) {
|
||||
const [{ payload }] = firstImageLoad;
|
||||
const selectedImage = selectLastSelectedImage(getState());
|
||||
if (selectedImage) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(payload.image_names.at(0) ?? null));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { truncate } from 'es-toolkit/compat';
|
||||
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { truncate } from 'lodash-es';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||
import { getImageUsage } from 'features/deleteImageModal/store/state';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
@@ -20,9 +21,10 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
|
||||
const nodes = selectNodesSlice(state);
|
||||
const canvas = selectCanvasSlice(state);
|
||||
const upscale = selectUpscaleSlice(state);
|
||||
const refImages = selectRefImagesSlice(state);
|
||||
|
||||
deleted_images.forEach((image_name) => {
|
||||
const imageUsage = getImageUsage(nodes, canvas, upscale, image_name);
|
||||
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
|
||||
|
||||
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
|
||||
dispatch(nodeEditorReset());
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
@@ -11,36 +11,35 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
// Cancel any in-progress instances of this listener, we don't want to select an image from a previous board
|
||||
cancelActiveListeners();
|
||||
|
||||
if (boardIdSelected.match(action) && action.payload.selectedImageName) {
|
||||
// This action already has a selected image name, we trust it is valid
|
||||
return;
|
||||
}
|
||||
|
||||
const state = getState();
|
||||
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const board_id = selectSelectedBoardId(state);
|
||||
|
||||
const queryArgs = { ...selectGetImageNamesQueryArgs(state), board_id };
|
||||
|
||||
// wait until the board has some images - maybe it already has some from a previous fetch
|
||||
// must use getState() to ensure we do not have stale state
|
||||
const isSuccess = await condition(
|
||||
() => imagesApi.endpoints.listImages.select(queryArgs)(getState()).isSuccess,
|
||||
() => imagesApi.endpoints.getImageNames.select(queryArgs)(getState()).isSuccess,
|
||||
5000
|
||||
);
|
||||
|
||||
if (isSuccess) {
|
||||
// the board was just changed - we can select the first image
|
||||
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
|
||||
|
||||
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
|
||||
const selectedImage = boardImagesData.items.find(
|
||||
(item) => item.image_name === action.payload.selectedImageName
|
||||
);
|
||||
dispatch(imageSelected(selectedImage || null));
|
||||
} else if (boardImagesData) {
|
||||
dispatch(imageSelected(boardImagesData.items[0] || null));
|
||||
} else {
|
||||
// board has no images - deselect
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
} else {
|
||||
// fallback - deselect
|
||||
if (!isSuccess) {
|
||||
dispatch(imageSelected(null));
|
||||
return;
|
||||
}
|
||||
|
||||
// the board was just changed - we can select the first image
|
||||
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(getState()).data?.image_names;
|
||||
|
||||
const imageToSelect = imageNames?.at(0) ?? null;
|
||||
|
||||
dispatch(imageSelected(imageToSelect));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import type { AlertStatus } from '@invoke-ai/ui-library';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
|
||||
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: enqueueRequestedCanvas,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
log.debug('Enqueue requested');
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const manager = $canvasManager.get();
|
||||
assert(manager, 'No canvas manager');
|
||||
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
switch (base) {
|
||||
case 'sdxl':
|
||||
return await buildSDXLGraph(state, manager);
|
||||
case 'sd-1':
|
||||
case `sd-2`:
|
||||
return await buildSD1Graph(state, manager);
|
||||
case `sd-3`:
|
||||
return await buildSD3Graph(state, manager);
|
||||
case `flux`:
|
||||
return await buildFLUXGraph(state, manager);
|
||||
case 'cogview4':
|
||||
return await buildCogView4Graph(state, manager);
|
||||
case 'imagen3':
|
||||
return await buildImagen3Graph(state, manager);
|
||||
case 'imagen4':
|
||||
return await buildImagen4Graph(state, manager);
|
||||
case 'chatgpt-4o':
|
||||
return await buildChatGPT4oGraph(state, manager);
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
});
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
let title = 'Failed to build graph';
|
||||
let status: AlertStatus = 'error';
|
||||
let description: string | null = null;
|
||||
if (buildGraphResult.error instanceof AssertionError) {
|
||||
description = extractMessageFromAssertionError(buildGraphResult.error);
|
||||
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
|
||||
title = 'Unsupported generation mode';
|
||||
description = buildGraphResult.error.message;
|
||||
status = 'warning';
|
||||
}
|
||||
const error = serializeError(buildGraphResult.error);
|
||||
log.error({ error }, 'Failed to build graph');
|
||||
toast({
|
||||
status,
|
||||
title,
|
||||
description,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
|
||||
|
||||
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'canvas',
|
||||
destination,
|
||||
})
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
return;
|
||||
}
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
|
||||
);
|
||||
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,44 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: enqueueRequestedUpscaling,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
|
||||
|
||||
const batchConfig = prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'upscaling',
|
||||
destination: 'gallery',
|
||||
});
|
||||
|
||||
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,73 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
imageDTO: ImageDTO;
|
||||
shiftKey: boolean;
|
||||
ctrlKey: boolean;
|
||||
metaKey: boolean;
|
||||
altKey: boolean;
|
||||
}>('gallery/imageClicked');
|
||||
|
||||
/**
|
||||
* This listener handles the logic for selecting images in the gallery.
|
||||
*
|
||||
* Previously, this logic was in a `useCallback` with the whole gallery selection as a dependency. Every time
|
||||
* the selection changed, the callback got recreated and all images rerendered. This could easily block for
|
||||
* hundreds of ms, more for lower end devices.
|
||||
*
|
||||
* Moving this logic into a listener means we don't need to recalculate anything dynamically and the gallery
|
||||
* is much more responsive.
|
||||
*/
|
||||
|
||||
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: galleryImageClicked,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
if (!queryResult.data) {
|
||||
// Should never happen if we have clicked a gallery image
|
||||
return;
|
||||
}
|
||||
|
||||
const imageDTOs = queryResult.data.items;
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
}
|
||||
} else if (shiftKey) {
|
||||
const rangeEndImageName = imageDTO.image_name;
|
||||
const lastSelectedImage = selection[selection.length - 1]?.image_name;
|
||||
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
|
||||
const currentClickedIndex = imageDTOs.findIndex((n) => n.image_name === rangeEndImageName);
|
||||
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
|
||||
// We have a valid range!
|
||||
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
||||
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
||||
const imagesToSelect = imageDTOs.slice(start, end + 1);
|
||||
dispatch(selectionChanged(selection.concat(imagesToSelect)));
|
||||
}
|
||||
} else if (ctrlKey || metaKey) {
|
||||
if (selection.some((i) => i.image_name === imageDTO.image_name) && selection.length > 1) {
|
||||
dispatch(selectionChanged(selection.filter((n) => n.image_name !== imageDTO.image_name)));
|
||||
} else {
|
||||
dispatch(selectionChanged(selection.concat(imageDTO)));
|
||||
}
|
||||
} else {
|
||||
dispatch(selectionChanged([imageDTO]));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,119 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, offsetChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const addGalleryOffsetChangedListener = (startAppListening: AppStartListening) => {
|
||||
/**
|
||||
* When the user changes pages in the gallery, we need to wait until the next page of images is loaded, then maybe
|
||||
* update the selection.
|
||||
*
|
||||
* There are a three scenarios:
|
||||
*
|
||||
* 1. The page is changed by clicking the pagination buttons. No changes to selection are needed.
|
||||
*
|
||||
* 2. The page is changed by using the arrow keys (without alt).
|
||||
* - When going backwards, select the last image.
|
||||
* - When going forwards, select the first image.
|
||||
*
|
||||
* 3. The page is changed by using the arrows keys with alt. This means the user is changing the comparison image.
|
||||
* - When going backwards, select the last image _as the comparison image_.
|
||||
* - When going forwards, select the first image _as the comparison image_.
|
||||
*/
|
||||
startAppListening({
|
||||
actionCreator: offsetChanged,
|
||||
effect: async (action, { dispatch, getState, getOriginalState, take, cancelActiveListeners }) => {
|
||||
// Cancel any active listeners to prevent the selection from changing without user input
|
||||
cancelActiveListeners();
|
||||
|
||||
const { withHotkey } = action.payload;
|
||||
|
||||
if (!withHotkey) {
|
||||
// User changed pages by clicking the pagination buttons - no changes to selection
|
||||
return;
|
||||
}
|
||||
|
||||
const originalState = getOriginalState();
|
||||
const prevOffset = originalState.gallery.offset;
|
||||
const offset = getState().gallery.offset;
|
||||
|
||||
if (offset === prevOffset) {
|
||||
// The page didn't change - bail
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* We need to wait until the next page of images is loaded before updating the selection, so we use the correct
|
||||
* page of images.
|
||||
*
|
||||
* The simplest way to do it would be to use `take` to wait for the next fulfilled action, but RTK-Q doesn't
|
||||
* dispatch an action on cache hits. This means the `take` will only return if the cache is empty. If the user
|
||||
* changes to a cached page - a common situation - the `take` will never resolve.
|
||||
*
|
||||
* So we need to take a two-step approach. First, check if we have data in the cache for the page of images. If
|
||||
* we have data cached, use it to update the selection. If we don't have data cached, wait for the next fulfilled
|
||||
* action, which updates the cache, then use the cache to update the selection.
|
||||
*/
|
||||
|
||||
// Check if we have data in the cache for the page of images
|
||||
const queryArgs = selectListImagesQueryArgs(getState());
|
||||
let { data } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
|
||||
|
||||
// No data yet - wait for the network request to complete
|
||||
if (!data) {
|
||||
const takeResult = await take(imagesApi.endpoints.listImages.matchFulfilled, 5000);
|
||||
if (!takeResult) {
|
||||
// The request didn't complete in time - bail
|
||||
return;
|
||||
}
|
||||
data = takeResult[0].payload;
|
||||
}
|
||||
|
||||
// We awaited a network request - state could have changed, get fresh state
|
||||
const state = getState();
|
||||
const { selection, imageToCompare } = state.gallery;
|
||||
const imageDTOs = data?.items;
|
||||
|
||||
if (!imageDTOs) {
|
||||
// The page didn't load - bail
|
||||
return;
|
||||
}
|
||||
|
||||
if (withHotkey === 'arrow') {
|
||||
// User changed pages by using the arrow keys - selection changes to first or last image depending
|
||||
if (offset < prevOffset) {
|
||||
// We've gone backwards
|
||||
const lastImage = imageDTOs[imageDTOs.length - 1];
|
||||
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
|
||||
dispatch(selectionChanged(lastImage ? [lastImage] : []));
|
||||
}
|
||||
} else {
|
||||
// We've gone forwards
|
||||
const firstImage = imageDTOs[0];
|
||||
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
|
||||
dispatch(selectionChanged(firstImage ? [firstImage] : []));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (withHotkey === 'alt+arrow') {
|
||||
// User changed pages by using the arrow keys with alt - comparison image changes to first or last depending
|
||||
if (offset < prevOffset) {
|
||||
// We've gone backwards
|
||||
const lastImage = imageDTOs[imageDTOs.length - 1];
|
||||
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
|
||||
dispatch(imageToCompareChanged(lastImage));
|
||||
}
|
||||
} else {
|
||||
// We've gone forwards
|
||||
const firstImage = imageDTOs[0];
|
||||
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
|
||||
dispatch(imageToCompareChanged(firstImage));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,9 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { size } from 'es-toolkit/compat';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||
import { size } from 'lodash-es';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
@@ -8,16 +8,16 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
|
||||
effect: (action) => {
|
||||
const { board_id, imageDTO } = action.meta.arg.originalArgs;
|
||||
log.debug({ board_id, imageDTO }, 'Image added to board');
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
log.debug({ board_id, image_name }, 'Image added to board');
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
|
||||
effect: (action) => {
|
||||
const { board_id, imageDTO } = action.meta.arg.originalArgs;
|
||||
log.debug({ board_id, imageDTO }, 'Problem adding image to board');
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
log.debug({ board_id, image_name }, 'Problem adding image to board');
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { entityDeleted, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { forEach, intersectionBy } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { Param0 } from 'tsafe';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
//TODO(psyche): handle image deletion (canvas staging area?)
|
||||
|
||||
// Some utils to delete images from different parts of the app
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
const actions: Param0<typeof dispatch>[] = [];
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name) {
|
||||
actions.push(
|
||||
fieldImageValueChanged({
|
||||
nodeId: node.data.id,
|
||||
fieldName: input.name,
|
||||
value: undefined,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (isImageFieldCollectionInputInstance(input)) {
|
||||
actions.push(
|
||||
fieldImageCollectionValueChanged({
|
||||
nodeId: node.data.id,
|
||||
fieldName: input.name,
|
||||
value: input.value?.filter((value) => value?.image_name !== imageDTO.image_name),
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
actions.forEach(dispatch);
|
||||
};
|
||||
|
||||
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
selectCanvasSlice(state).controlLayers.entities.forEach(({ id, objects }) => {
|
||||
let shouldDelete = false;
|
||||
for (const obj of objects) {
|
||||
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
|
||||
shouldDelete = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (shouldDelete) {
|
||||
dispatch(entityDeleted({ entityIdentifier: { id, type: 'control_layer' } }));
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
|
||||
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const deleteRasterLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
|
||||
let shouldDelete = false;
|
||||
for (const obj of objects) {
|
||||
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
|
||||
shouldDelete = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (shouldDelete) {
|
||||
dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
export const addImageDeletionListeners = (startAppListening: AppStartListening) => {
|
||||
// Handle single image deletion
|
||||
startAppListening({
|
||||
actionCreator: imageDeletionConfirmed,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { imageDTOs, imagesUsage } = action.payload;
|
||||
|
||||
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
|
||||
// handle multiples in separate listener
|
||||
return;
|
||||
}
|
||||
|
||||
const imageDTO = imageDTOs[0];
|
||||
const imageUsage = imagesUsage[0];
|
||||
|
||||
if (!imageDTO || !imageUsage) {
|
||||
// satisfy noUncheckedIndexedAccess
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const state = getState();
|
||||
await dispatch(imagesApi.endpoints.deleteImage.initiate(imageDTO)).unwrap();
|
||||
|
||||
if (state.gallery.selection.some((i) => i.image_name === imageDTO.image_name)) {
|
||||
// The deleted image was a selected image, we need to select the next image
|
||||
const newSelection = state.gallery.selection.filter((i) => i.image_name !== imageDTO.image_name);
|
||||
|
||||
if (newSelection.length > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the current list of images and select the same index
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const data = imagesApi.endpoints.listImages.select(baseQueryArgs)(state).data;
|
||||
|
||||
if (data) {
|
||||
const deletedImageIndex = data.items.findIndex((i) => i.image_name === imageDTO.image_name);
|
||||
const nextImage = data.items[deletedImageIndex + 1] ?? data.items[0] ?? null;
|
||||
if (nextImage?.image_name === imageDTO.image_name) {
|
||||
// If the next image is the same as the deleted one, it means it was the last image, reset selection
|
||||
dispatch(imageSelected(null));
|
||||
} else {
|
||||
dispatch(imageSelected(nextImage));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deleteNodesImages(state, dispatch, imageDTO);
|
||||
deleteReferenceImages(state, dispatch, imageDTO);
|
||||
deleteRasterLayerImages(state, dispatch, imageDTO);
|
||||
deleteControlLayerImages(state, dispatch, imageDTO);
|
||||
} catch {
|
||||
// no-op
|
||||
} finally {
|
||||
dispatch(isModalOpenChanged(false));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// Handle multiple image deletion
|
||||
startAppListening({
|
||||
actionCreator: imageDeletionConfirmed,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { imageDTOs, imagesUsage } = action.payload;
|
||||
|
||||
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
|
||||
// handle singles in separate listener
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const state = getState();
|
||||
await dispatch(imagesApi.endpoints.deleteImages.initiate({ imageDTOs })).unwrap();
|
||||
|
||||
if (intersectionBy(state.gallery.selection, imageDTOs, 'image_name').length > 0) {
|
||||
// Some selected images were deleted, need to select the next image
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
if (data) {
|
||||
// When we delete multiple images, we clear the selection. Then, the the next time we load images, we will
|
||||
// select the first one. This is handled below in the listener for `imagesApi.endpoints.listImages.matchFulfilled`.
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
|
||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||
|
||||
imageDTOs.forEach((imageDTO) => {
|
||||
deleteNodesImages(state, dispatch, imageDTO);
|
||||
deleteControlLayerImages(state, dispatch, imageDTO);
|
||||
deleteReferenceImages(state, dispatch, imageDTO);
|
||||
deleteRasterLayerImages(state, dispatch, imageDTO);
|
||||
});
|
||||
} catch {
|
||||
// no-op
|
||||
} finally {
|
||||
dispatch(isModalOpenChanged(false));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// When we list images, if no images is selected, select the first one.
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.listImages.matchFulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const selection = getState().gallery.selection;
|
||||
if (selection.length === 0) {
|
||||
dispatch(imageSelected(action.payload.items[0] ?? null));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
|
||||
effect: (action) => {
|
||||
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.deleteImage.matchRejected,
|
||||
effect: (action) => {
|
||||
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,32 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||
import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||
|
||||
export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: imagesToDeleteSelected,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const imageDTOs = action.payload;
|
||||
const state = getState();
|
||||
const { shouldConfirmOnDelete } = state.system;
|
||||
const imagesUsage = selectImageUsage(getState());
|
||||
|
||||
const isImageInUse =
|
||||
imagesUsage.some((i) => i.isRasterLayerImage) ||
|
||||
imagesUsage.some((i) => i.isControlLayerImage) ||
|
||||
imagesUsage.some((i) => i.isReferenceImage) ||
|
||||
imagesUsage.some((i) => i.isInpaintMaskImage) ||
|
||||
imagesUsage.some((i) => i.isUpscaleImage) ||
|
||||
imagesUsage.some((i) => i.isNodesImage) ||
|
||||
imagesUsage.some((i) => i.isRegionalGuidanceImage);
|
||||
|
||||
if (shouldConfirmOnDelete || isImageInUse) {
|
||||
dispatch(isModalOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -2,12 +2,12 @@ import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { omit } from 'es-toolkit/compat';
|
||||
import { imageUploadedClientSide } from 'features/gallery/store/actions';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { omit } from 'lodash-es';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const addImagesStarredListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.starImages.matchFulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { updated_image_names: starredImages } = action.payload;
|
||||
|
||||
const state = getState();
|
||||
|
||||
const { selection } = state.gallery;
|
||||
const updatedSelection: ImageDTO[] = [];
|
||||
|
||||
selection.forEach((selectedImageDTO) => {
|
||||
if (starredImages.includes(selectedImageDTO.image_name)) {
|
||||
updatedSelection.push({
|
||||
...selectedImageDTO,
|
||||
starred: true,
|
||||
});
|
||||
} else {
|
||||
updatedSelection.push(selectedImageDTO);
|
||||
}
|
||||
});
|
||||
dispatch(selectionChanged(updatedSelection));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,30 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { updated_image_names: unstarredImages } = action.payload;
|
||||
|
||||
const state = getState();
|
||||
|
||||
const { selection } = state.gallery;
|
||||
const updatedSelection: ImageDTO[] = [];
|
||||
|
||||
selection.forEach((selectedImageDTO) => {
|
||||
if (unstarredImages.includes(selectedImageDTO.image_name)) {
|
||||
updatedSelection.push({
|
||||
...selectedImageDTO,
|
||||
starred: false,
|
||||
});
|
||||
} else {
|
||||
updatedSelection.push(selectedImageDTO);
|
||||
}
|
||||
});
|
||||
dispatch(selectionChanged(updatedSelection));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,14 +1,28 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
|
||||
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
|
||||
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
|
||||
import {
|
||||
selectAllEntitiesOfType,
|
||||
selectBboxModelBase,
|
||||
selectCanvasSlice,
|
||||
} from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isChatGPT4oModelConfig,
|
||||
isFluxKontextApiModelConfig,
|
||||
isFluxKontextModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
@@ -25,9 +39,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
|
||||
const newModel = result.data;
|
||||
|
||||
const newBaseModel = newModel.base;
|
||||
const didBaseModelChange = state.params.model?.base !== newBaseModel;
|
||||
const newBase = newModel.base;
|
||||
const didBaseModelChange = state.params.model?.base !== newBase;
|
||||
|
||||
if (didBaseModelChange) {
|
||||
// we may need to reset some incompatible submodels
|
||||
@@ -35,7 +48,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
|
||||
// handle incompatible loras
|
||||
state.loras.loras.forEach((lora) => {
|
||||
if (lora.model.base !== newBaseModel) {
|
||||
if (lora.model.base !== newBase) {
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
@@ -43,20 +56,82 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
|
||||
// handle incompatible vae
|
||||
const { vae } = state.params;
|
||||
if (vae && vae.base !== newBaseModel) {
|
||||
if (vae && vae.base !== newBase) {
|
||||
dispatch(vaeSelected(null));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
||||
// handle incompatible controlnets
|
||||
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
|
||||
// if (ca.model?.base !== newBaseModel) {
|
||||
// modelsCleared += 1;
|
||||
// if (ca.isEnabled) {
|
||||
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
|
||||
// to choose the best available model based on the new main model.
|
||||
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
|
||||
|
||||
let newGlobalRefImageModel = null;
|
||||
|
||||
// Certain models require the ref image model to be the same as the main model - others just need a matching
|
||||
// base. Helper to grab the first exact match or the first available model if no exact match is found.
|
||||
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
|
||||
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
|
||||
|
||||
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
|
||||
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
|
||||
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
|
||||
} else if (newModel.base === 'chatgpt-4o') {
|
||||
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
|
||||
} else if (newModel.base === 'flux-kontext') {
|
||||
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
|
||||
} else if (newModel.base === 'flux') {
|
||||
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
|
||||
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
|
||||
} else {
|
||||
newGlobalRefImageModel = allRefImageModels[0] ?? null;
|
||||
}
|
||||
|
||||
// All ref image entities are updated to use the same new model
|
||||
const refImageEntities = selectReferenceImageEntities(state);
|
||||
for (const entity of refImageEntities) {
|
||||
const shouldUpdateModel =
|
||||
(entity.config.model && entity.config.model.base !== newBase) ||
|
||||
(!entity.config.model && newGlobalRefImageModel);
|
||||
|
||||
if (shouldUpdateModel) {
|
||||
dispatch(
|
||||
refImageModelChanged({
|
||||
id: entity.id,
|
||||
modelConfig: newGlobalRefImageModel,
|
||||
})
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// For regional guidance, there is no smart logic - we just pick the first available model.
|
||||
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
|
||||
|
||||
// All regional guidance entities are updated to use the same new model.
|
||||
const canvasState = selectCanvasSlice(state);
|
||||
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
|
||||
for (const entity of canvasRegionalGuidanceEntities) {
|
||||
for (const refImage of entity.referenceImages) {
|
||||
// Only change the model if the current one is not compatible with the new base model.
|
||||
const shouldUpdateModel =
|
||||
(refImage.config.model && refImage.config.model.base !== newBase) ||
|
||||
(!refImage.config.model && newRegionalRefImageModel);
|
||||
|
||||
if (shouldUpdateModel) {
|
||||
dispatch(
|
||||
rgRefImageModelChanged({
|
||||
entityIdentifier: getEntityIdentifier(entity),
|
||||
referenceImageId: refImage.id,
|
||||
modelConfig: newRegionalRefImageModel,
|
||||
})
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (modelsCleared > 0) {
|
||||
toast({
|
||||
@@ -71,9 +146,16 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
|
||||
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
|
||||
|
||||
const modelBase = selectBboxModelBase(state);
|
||||
if (!selectIsStaging(state) && modelBase !== state.params.model?.base) {
|
||||
dispatch(bboxSyncedToOptimalDimension());
|
||||
|
||||
if (modelBase !== state.params.model?.base) {
|
||||
// Sync generate tab settings whenever the model base changes
|
||||
dispatch(syncedToOptimalDimension());
|
||||
if (!selectIsStaging(state)) {
|
||||
// Canvas tab only syncs if not staging
|
||||
dispatch(bboxSyncedToOptimalDimension());
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import {
|
||||
controlLayerModelChanged,
|
||||
referenceImageIPAdapterModelChanged,
|
||||
rgIPAdapterModelChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { controlLayerModelChanged, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import {
|
||||
clipEmbedModelSelected,
|
||||
@@ -15,8 +11,9 @@ import {
|
||||
t5EncoderModelSelected,
|
||||
vaeSelected,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import {
|
||||
@@ -210,12 +207,12 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
|
||||
|
||||
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const ipaModels = models.filter(isIPAdapterModelConfig);
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.type !== 'ip_adapter') {
|
||||
selectRefImagesSlice(state).entities.forEach((entity) => {
|
||||
if (!isIPAdapterConfig(entity.config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedIPAdapterModel = entity.ipAdapter.model;
|
||||
const selectedIPAdapterModel = entity.config.model;
|
||||
// `null` is a valid IP adapter model - no need to do anything.
|
||||
if (!selectedIPAdapterModel) {
|
||||
return;
|
||||
@@ -225,16 +222,16 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
return;
|
||||
}
|
||||
log.debug({ selectedIPAdapterModel }, 'Selected IP adapter model is not available, clearing');
|
||||
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
|
||||
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
|
||||
});
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
|
||||
if (ipAdapter.type !== 'ip_adapter') {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
|
||||
if (!isIPAdapterConfig(config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedIPAdapterModel = ipAdapter.model;
|
||||
const selectedIPAdapterModel = config.model;
|
||||
// `null` is a valid IP adapter model - no need to do anything.
|
||||
if (!selectedIPAdapterModel) {
|
||||
return;
|
||||
@@ -245,7 +242,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
}
|
||||
log.debug({ selectedIPAdapterModel }, 'Selected IP adapter model is not available, clearing');
|
||||
dispatch(
|
||||
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
|
||||
rgRefImageModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -254,11 +251,11 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
|
||||
|
||||
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.type !== 'flux_redux') {
|
||||
selectRefImagesSlice(state).entities.forEach((entity) => {
|
||||
if (!isFLUXReduxConfig(entity.config)) {
|
||||
return;
|
||||
}
|
||||
const selectedFLUXReduxModel = entity.ipAdapter.model;
|
||||
const selectedFLUXReduxModel = entity.config.model;
|
||||
// `null` is a valid FLUX Redux model - no need to do anything.
|
||||
if (!selectedFLUXReduxModel) {
|
||||
return;
|
||||
@@ -268,16 +265,16 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
return;
|
||||
}
|
||||
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
|
||||
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
|
||||
dispatch(refImageModelChanged({ id: entity.id, modelConfig: null }));
|
||||
});
|
||||
|
||||
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
|
||||
if (ipAdapter.type !== 'flux_redux') {
|
||||
entity.referenceImages.forEach(({ id: referenceImageId, config }) => {
|
||||
if (!isFLUXReduxConfig(config)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedFLUXReduxModel = ipAdapter.model;
|
||||
const selectedFLUXReduxModel = config.model;
|
||||
// `null` is a valid FLUX Redux model - no need to do anything.
|
||||
if (!selectedFLUXReduxModel) {
|
||||
return;
|
||||
@@ -288,7 +285,7 @@ const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
}
|
||||
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
|
||||
dispatch(
|
||||
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
|
||||
rgRefImageModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isNil } from 'es-toolkit';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
heightChanged,
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setGuidance,
|
||||
@@ -9,6 +11,7 @@ import {
|
||||
setSteps,
|
||||
vaePrecisionChanged,
|
||||
vaeSelected,
|
||||
widthChanged,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import {
|
||||
@@ -23,6 +26,7 @@ import {
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { t } from 'i18next';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
@@ -86,10 +90,16 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_rescale_multiplier) {
|
||||
if (!isNil(cfg_rescale_multiplier)) {
|
||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||
}
|
||||
} else {
|
||||
// Set this to 0 if it doesn't have a default. This value is
|
||||
// easy to miss in the UI when users are resetting defaults
|
||||
// and leaving it non-zero could lead to detrimental
|
||||
// effects.
|
||||
dispatch(setCfgRescaleMultiplier(0));
|
||||
}
|
||||
|
||||
if (steps) {
|
||||
@@ -106,15 +116,24 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
const setSizeOptions = { updateAspectRatio: true, clamp: true };
|
||||
|
||||
const isStaging = selectIsStaging(getState());
|
||||
if (!isStaging && width) {
|
||||
const activeTab = selectActiveTab(getState());
|
||||
if (activeTab === 'generate') {
|
||||
if (isParameterWidth(width)) {
|
||||
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
|
||||
dispatch(widthChanged({ width, ...setSizeOptions }));
|
||||
}
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(heightChanged({ height, ...setSizeOptions }));
|
||||
}
|
||||
}
|
||||
|
||||
if (!isStaging && height) {
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
|
||||
if (activeTab === 'canvas') {
|
||||
if (!isStaging) {
|
||||
if (isParameterWidth(width)) {
|
||||
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
|
||||
}
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { objectEquals } from '@observ33r/object-equals';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import { api } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
@@ -64,7 +64,7 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
|
||||
const nextQueueStatusData = await queueStatusRequest.unwrap();
|
||||
|
||||
// If the queue hasn't changed, we don't need to do anything.
|
||||
if (isEqual(prevQueueStatusData?.queue, nextQueueStatusData.queue)) {
|
||||
if (objectEquals(prevQueueStatusData?.queue, nextQueueStatusData.queue)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { atom } from 'nanostores';
|
||||
|
||||
@@ -32,11 +31,3 @@ export const getStore = () => {
|
||||
}
|
||||
return store;
|
||||
};
|
||||
|
||||
export const useAppStore = () => {
|
||||
const store = useStore($store);
|
||||
if (!store) {
|
||||
throw new ReduxStoreNotInitialized();
|
||||
}
|
||||
return store;
|
||||
};
|
||||
|
||||
@@ -11,5 +11,7 @@ export const $false: ReadableAtom<boolean> = atom(false);
|
||||
/**
|
||||
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
|
||||
* in a hook or component.
|
||||
*
|
||||
* @knipignore
|
||||
*/
|
||||
export const $true: ReadableAtom<boolean> = atom(true);
|
||||
|
||||
@@ -4,19 +4,19 @@ import { logger } from 'app/logging/logger';
|
||||
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
|
||||
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { keys, mergeWith, omit, pick } from 'es-toolkit/compat';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
canvasSessionSlice,
|
||||
canvasStagingAreaPersistConfig,
|
||||
canvasStagingAreaSlice,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
|
||||
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
|
||||
@@ -28,7 +28,6 @@ import { configSlice } from 'features/system/store/configSlice';
|
||||
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
|
||||
import { uiPersistConfig, uiSlice } from 'features/ui/store/uiSlice';
|
||||
import { diff } from 'jsondiffpatch';
|
||||
import { keys, mergeWith, omit, pick } from 'lodash-es';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
@@ -54,20 +53,19 @@ const allReducers = {
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
[uiSlice.name]: uiSlice.reducer,
|
||||
[dynamicPromptsSlice.name]: dynamicPromptsSlice.reducer,
|
||||
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
|
||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[upscaleSlice.name]: upscaleSlice.reducer,
|
||||
[stylePresetSlice.name]: stylePresetSlice.reducer,
|
||||
[paramsSlice.name]: paramsSlice.reducer,
|
||||
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
|
||||
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
|
||||
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
[workflowLibrarySlice.name]: workflowLibrarySlice.reducer,
|
||||
[refImagesSlice.name]: refImagesSlice.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@@ -103,7 +101,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[uiPersistConfig.name]: uiPersistConfig,
|
||||
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[canvasPersistConfig.name]: canvasPersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
[upscalePersistConfig.name]: upscalePersistConfig,
|
||||
@@ -113,6 +110,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
[workflowLibraryPersistConfig.name]: workflowLibraryPersistConfig,
|
||||
[refImagesSlice.name]: refImagesPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
@@ -175,6 +173,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.concat(authToastMiddleware)
|
||||
// .concat(getDebugLoggerMiddleware())
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
@@ -209,3 +208,4 @@ export type RootState = ReturnType<AppStore['getState']>;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
|
||||
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
|
||||
export type AppGetState = ReturnType<typeof createStore>['getState'];
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import type { AppStore, AppThunkDispatch, RootState } from 'app/store/store';
|
||||
import type { TypedUseSelectorHook } from 'react-redux';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||
export const useAppStore = () => useStore<RootState>();
|
||||
export const useAppStore = () => useStore.withTypes<AppStore>()();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { Selector } from '@reduxjs/toolkit';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
/**
|
||||
|
||||
@@ -14,6 +14,7 @@ export type AppFeature =
|
||||
| 'githubLink'
|
||||
| 'discordLink'
|
||||
| 'bugLink'
|
||||
| 'aboutModal'
|
||||
| 'localization'
|
||||
| 'consoleLogging'
|
||||
| 'dynamicPrompting'
|
||||
@@ -29,7 +30,8 @@ export type AppFeature =
|
||||
| 'hfToken'
|
||||
| 'retryQueueItem'
|
||||
| 'cancelAndClearAll'
|
||||
| 'chatGPT4oHigh';
|
||||
| 'chatGPT4oHigh'
|
||||
| 'modelRelationships';
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
*/
|
||||
@@ -76,6 +78,7 @@ export type AppConfig = {
|
||||
allowPrivateStylePresets: boolean;
|
||||
allowClientSideUpload: boolean;
|
||||
allowPublishWorkflows: boolean;
|
||||
allowPromptExpansion: boolean;
|
||||
disabledTabs: TabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
import { Box, type BoxProps, type SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { type FocusRegionName, useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { selectSystemShouldEnableHighlightFocusedRegions } from 'features/system/store/systemSlice';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
|
||||
interface FocusRegionWrapperProps extends BoxProps {
|
||||
region: FocusRegionName;
|
||||
focusOnMount?: boolean;
|
||||
}
|
||||
|
||||
const FOCUS_REGION_STYLES: SystemStyleObject = {
|
||||
position: 'relative',
|
||||
'&[data-highlighted="true"]::after': {
|
||||
borderColor: 'blue.700',
|
||||
},
|
||||
'&::after': {
|
||||
content: '""',
|
||||
position: 'absolute',
|
||||
inset: 0,
|
||||
zIndex: 1,
|
||||
borderRadius: 'base',
|
||||
border: '2px solid',
|
||||
borderColor: 'transparent',
|
||||
pointerEvents: 'none',
|
||||
transition: 'border-color 0.1s ease-in-out',
|
||||
},
|
||||
};
|
||||
|
||||
export const FocusRegionWrapper = memo(
|
||||
({ region, focusOnMount = false, sx, children, ...boxProps }: FocusRegionWrapperProps) => {
|
||||
const shouldHighlightFocusedRegions = useAppSelector(selectSystemShouldEnableHighlightFocusedRegions);
|
||||
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
|
||||
const options = useMemo(() => ({ focusOnMount }), [focusOnMount]);
|
||||
|
||||
useFocusRegion(region, ref, options);
|
||||
const isFocused = useIsRegionFocused(region);
|
||||
const isHighlighted = isFocused && shouldHighlightFocusedRegions;
|
||||
|
||||
return (
|
||||
<Box
|
||||
ref={ref}
|
||||
tabIndex={-1}
|
||||
sx={useMemo(() => ({ ...FOCUS_REGION_STYLES, ...sx }), [sx])}
|
||||
data-highlighted={isHighlighted}
|
||||
{...boxProps}
|
||||
>
|
||||
{children}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
FocusRegionWrapper.displayName = 'FocusRegionWrapper';
|
||||
@@ -15,9 +15,9 @@ import {
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { merge, omit } from 'es-toolkit/compat';
|
||||
import { selectSystemSlice, setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type { ReactElement } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -8,20 +8,16 @@ const Loading = () => {
|
||||
return (
|
||||
<Flex
|
||||
position="absolute"
|
||||
width="100dvw"
|
||||
height="100dvh"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
bg="#151519"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
bg="hsl(220 12% 10% / 1)" // base.900
|
||||
inset={0}
|
||||
zIndex={99999}
|
||||
>
|
||||
<Image src={InvokeLogoWhite} w="8rem" h="8rem" />
|
||||
<Spinner
|
||||
label="Loading"
|
||||
color="grey"
|
||||
color="hsl(220 12% 68% / 1)" // base.300
|
||||
position="absolute"
|
||||
size="sm"
|
||||
width="24px !important"
|
||||
|
||||
@@ -11,13 +11,14 @@ import { memo, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = PropsWithChildren & {
|
||||
maxHeight?: ChakraProps['maxHeight'];
|
||||
maxWidth?: ChakraProps['maxWidth'];
|
||||
overflowX?: 'hidden' | 'scroll';
|
||||
overflowY?: 'hidden' | 'scroll';
|
||||
};
|
||||
|
||||
const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0, bottom: 0 };
|
||||
|
||||
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
|
||||
const ScrollableContent = ({ children, maxHeight, maxWidth, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
|
||||
const overlayscrollbarsOptions = useMemo(
|
||||
() => getOverlayScrollbarsParams({ overflowX, overflowY }).options,
|
||||
[overflowX, overflowY]
|
||||
@@ -44,7 +45,7 @@ const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflow
|
||||
}, [os]);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" maxHeight={maxHeight} position="relative">
|
||||
<Flex w="full" h="full" maxHeight={maxHeight} maxWidth={maxWidth} position="relative">
|
||||
<Box position="absolute" top={0} left={0} right={0} bottom={0}>
|
||||
<OverlayScrollbarsComponent ref={osRef} style={styles} options={overlayscrollbarsOptions}>
|
||||
{children}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { merge } from 'lodash-es';
|
||||
import { merge } from 'es-toolkit/compat';
|
||||
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
|
||||
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.os-scrollbar {
|
||||
/* The size of the scrollbar */
|
||||
--os-size: 9px;
|
||||
--os-size: 8px;
|
||||
/* The axis-perpedicular padding of the scrollbar (horizontal: padding-y, vertical: padding-x) */
|
||||
/* --os-padding-perpendicular: 0; */
|
||||
/* The axis padding of the scrollbar (horizontal: padding-x, vertical: padding-y) */
|
||||
@@ -8,11 +8,11 @@
|
||||
/* The border radius of the scrollbar track */
|
||||
/* --os-track-border-radius: 0; */
|
||||
/* The background of the scrollbar track */
|
||||
/* --os-track-bg: rgba(0, 0, 0, 0.3); */
|
||||
--os-track-bg: rgba(0, 0, 0, 0.5);
|
||||
/* The :hover background of the scrollbar track */
|
||||
/* --os-track-bg-hover: rgba(0, 0, 0, 0.3); */
|
||||
--os-track-bg-hover: rgba(0, 0, 0, 0.5);
|
||||
/* The :active background of the scrollbar track */
|
||||
/* --os-track-bg-active: rgba(0, 0, 0, 0.3); */
|
||||
--os-track-bg-active: rgba(0, 0, 0, 0.6);
|
||||
/* The border of the scrollbar track */
|
||||
/* --os-track-border: none; */
|
||||
/* The :hover background of the scrollbar track */
|
||||
@@ -22,11 +22,11 @@
|
||||
/* The border radius of the scrollbar handle */
|
||||
/* --os-handle-border-radius: 2px; */
|
||||
/* The background of the scrollbar handle */
|
||||
/* --os-handle-bg: var(--invokeai-colors-accentAlpha-500); */
|
||||
--os-handle-bg: var(--invoke-colors-base-400);
|
||||
/* The :hover background of the scrollbar handle */
|
||||
/* --os-handle-bg-hover: var(--invokeai-colors-accentAlpha-700); */
|
||||
--os-handle-bg-hover: var(--invoke-colors-base-300);
|
||||
/* The :active background of the scrollbar handle */
|
||||
/* --os-handle-bg-active: var(--invokeai-colors-accentAlpha-800); */
|
||||
--os-handle-bg-active: var(--invoke-colors-base-250);
|
||||
/* The border of the scrollbar handle */
|
||||
/* --os-handle-border: none; */
|
||||
/* The :hover border of the scrollbar handle */
|
||||
@@ -34,23 +34,23 @@
|
||||
/* The :active border of the scrollbar handle */
|
||||
/* --os-handle-border-active: none; */
|
||||
/* The min size of the scrollbar handle */
|
||||
--os-handle-min-size: 50px;
|
||||
--os-handle-min-size: 32px;
|
||||
/* The max size of the scrollbar handle */
|
||||
/* --os-handle-max-size: none; */
|
||||
/* The axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
|
||||
/* --os-handle-perpendicular-size: 100%; */
|
||||
/* The :hover axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
|
||||
/* --os-handle-perpendicular-size-hover: 100%; */
|
||||
--os-handle-perpendicular-size-hover: 100%;
|
||||
/* The :active axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
|
||||
/* --os-handle-perpendicular-size-active: 100%; */
|
||||
/* Increases the interactive area of the scrollbar handle. */
|
||||
/* --os-handle-interactive-area-offset: 0; */
|
||||
--os-handle-interactive-area-offset: -1px;
|
||||
}
|
||||
|
||||
.os-scrollbar-handle {
|
||||
cursor: grab;
|
||||
/* cursor: grab; */
|
||||
}
|
||||
|
||||
.os-scrollbar-handle:active {
|
||||
cursor: grabbing;
|
||||
/* cursor: grabbing; */
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
|
||||
[uniqueGroupKey]: true,
|
||||
});
|
||||
|
||||
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
|
||||
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
|
||||
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
|
||||
};
|
||||
|
||||
@@ -198,6 +198,10 @@ type PickerProps<T extends object> = {
|
||||
* Whether the picker should be searchable. If true, renders a search input.
|
||||
*/
|
||||
searchable?: boolean;
|
||||
/**
|
||||
* Initial state for group toggles. If provided, groups will start with these states instead of all being disabled.
|
||||
*/
|
||||
initialGroupStates?: GroupStatusMap;
|
||||
};
|
||||
|
||||
export type PickerContextState<T extends object> = {
|
||||
@@ -310,9 +314,9 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
|
||||
return flattened;
|
||||
};
|
||||
|
||||
type GroupStatusMap = Record<string, boolean>;
|
||||
export type GroupStatusMap = Record<string, boolean>;
|
||||
|
||||
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], initialGroupStates?: GroupStatusMap) => {
|
||||
const groupsWithOptions = useMemo(() => {
|
||||
const ids: string[] = [];
|
||||
for (const optionOrGroup of options) {
|
||||
@@ -332,14 +336,16 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
const groupStatusMap = $groupStatusMap.get();
|
||||
const newMap: GroupStatusMap = {};
|
||||
for (const id of groupsWithOptions) {
|
||||
if (newMap[id] === undefined) {
|
||||
newMap[id] = false;
|
||||
if (initialGroupStates && initialGroupStates[id] !== undefined) {
|
||||
newMap[id] = initialGroupStates[id];
|
||||
} else if (groupStatusMap[id] !== undefined) {
|
||||
newMap[id] = groupStatusMap[id];
|
||||
} else {
|
||||
newMap[id] = false;
|
||||
}
|
||||
}
|
||||
$groupStatusMap.set(newMap);
|
||||
}, [groupsWithOptions, $groupStatusMap]);
|
||||
}, [groupsWithOptions, $groupStatusMap, initialGroupStates]);
|
||||
|
||||
const toggleGroup = useCallback(
|
||||
(idToToggle: string) => {
|
||||
@@ -511,10 +517,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
OptionComponent = DefaultOptionComponent,
|
||||
NextToSearchBar,
|
||||
searchable,
|
||||
initialGroupStates,
|
||||
} = props;
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
|
||||
optionsOrGroups,
|
||||
initialGroupStates
|
||||
);
|
||||
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
|
||||
const $compactView = useAtom(true);
|
||||
const $optionsOrGroups = useAtom(optionsOrGroups);
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
useNewCanvasSession,
|
||||
useNewGallerySession,
|
||||
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
|
||||
import { allEntitiesDeleted } from 'features/controlLayers/store/canvasSlice';
|
||||
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsCounterClockwiseBold, PiFilePlusBold } from 'react-icons/pi';
|
||||
import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
export const SessionMenuItems = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { newGallerySessionWithDialog } = useNewGallerySession();
|
||||
const { newCanvasSessionWithDialog } = useNewCanvasSession();
|
||||
|
||||
const resetCanvasLayers = useCallback(() => {
|
||||
dispatch(allEntitiesDeleted());
|
||||
}, [dispatch]);
|
||||
@@ -23,12 +18,6 @@ export const SessionMenuItems = memo(() => {
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<>
|
||||
<MenuItem icon={<PiFilePlusBold />} onClick={newGallerySessionWithDialog}>
|
||||
{t('controlLayers.newGallerySession')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiFilePlusBold />} onClick={newCanvasSessionWithDialog}>
|
||||
{t('controlLayers.newCanvasSession')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
|
||||
{t('controlLayers.resetCanvasLayers')}
|
||||
</MenuItem>
|
||||
|
||||
@@ -6,6 +6,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
import { objectKeys } from 'tsafe';
|
||||
import z from 'zod/v4';
|
||||
|
||||
/**
|
||||
* We need to manage focus regions to conditionally enable hotkeys:
|
||||
@@ -30,23 +31,34 @@ const log = logger('system');
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
export type FocusRegionName = 'gallery' | 'layers' | 'canvas' | 'workflows' | 'viewer';
|
||||
const zFocusRegionName = z.enum([
|
||||
'launchpad',
|
||||
'viewer',
|
||||
'gallery',
|
||||
'boards',
|
||||
'layers',
|
||||
'canvas',
|
||||
'workflows',
|
||||
'progress',
|
||||
'settings',
|
||||
]);
|
||||
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
|
||||
|
||||
/**
|
||||
* A map of focus regions to the elements that are part of that region.
|
||||
*/
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = {
|
||||
gallery: new Set<HTMLElement>(),
|
||||
layers: new Set<HTMLElement>(),
|
||||
canvas: new Set<HTMLElement>(),
|
||||
workflows: new Set<HTMLElement>(),
|
||||
viewer: new Set<HTMLElement>(),
|
||||
} as const;
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
|
||||
(acc, region) => {
|
||||
acc[region] = new Set<HTMLElement>();
|
||||
return acc;
|
||||
},
|
||||
{} as Record<FocusRegionName, Set<HTMLElement>>
|
||||
);
|
||||
|
||||
/**
|
||||
* The currently-focused region or `null` if no region is focused.
|
||||
*/
|
||||
export const $focusedRegion = atom<FocusRegionName | null>(null);
|
||||
const $focusedRegion = atom<FocusRegionName | null>(null);
|
||||
|
||||
/**
|
||||
* A map of focus regions to atoms that indicate if that region is focused.
|
||||
@@ -62,11 +74,13 @@ const FOCUS_REGIONS = objectKeys(REGION_TARGETS).reduce(
|
||||
/**
|
||||
* Sets the focused region, logging a trace level message.
|
||||
*/
|
||||
const setFocus = (region: FocusRegionName | null) => {
|
||||
export const setFocusedRegion = (region: FocusRegionName | null) => {
|
||||
$focusedRegion.set(region);
|
||||
log.trace(`Focus changed: ${region}`);
|
||||
};
|
||||
|
||||
export const getFocusedRegion = () => $focusedRegion.get();
|
||||
|
||||
type UseFocusRegionOptions = {
|
||||
focusOnMount?: boolean;
|
||||
};
|
||||
@@ -99,14 +113,14 @@ export const useFocusRegion = (
|
||||
REGION_TARGETS[region].add(element);
|
||||
|
||||
if (focusOnMount) {
|
||||
setFocus(region);
|
||||
setFocusedRegion(region);
|
||||
}
|
||||
|
||||
return () => {
|
||||
REGION_TARGETS[region].delete(element);
|
||||
|
||||
if (REGION_TARGETS[region].size === 0 && $focusedRegion.get() === region) {
|
||||
setFocus(null);
|
||||
setFocusedRegion(null);
|
||||
}
|
||||
};
|
||||
}, [options, ref, region]);
|
||||
@@ -163,7 +177,7 @@ const onFocus = (_: FocusEvent) => {
|
||||
return;
|
||||
}
|
||||
|
||||
setFocus(focusedRegion);
|
||||
setFocusedRegion(focusedRegion);
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
115
invokeai/frontend/web/src/common/hooks/useAsyncState.ts
Normal file
115
invokeai/frontend/web/src/common/hooks/useAsyncState.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { WrappedError } from 'common/util/result';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type SuccessState<T> = {
|
||||
status: 'success';
|
||||
value: T;
|
||||
error: null;
|
||||
};
|
||||
|
||||
type ErrorState = {
|
||||
status: 'error';
|
||||
value: null;
|
||||
error: Error;
|
||||
};
|
||||
|
||||
type PendingState = {
|
||||
status: 'pending';
|
||||
value: null;
|
||||
error: null;
|
||||
};
|
||||
|
||||
type IdleState = {
|
||||
status: 'idle';
|
||||
value: null;
|
||||
error: null;
|
||||
};
|
||||
|
||||
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
|
||||
|
||||
type UseAsyncStateOptions = {
|
||||
immediate?: boolean;
|
||||
};
|
||||
|
||||
type UseAsyncReturn<T> = {
|
||||
$state: Atom<State<T>>;
|
||||
trigger: () => Promise<void>;
|
||||
reset: () => void;
|
||||
};
|
||||
|
||||
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
|
||||
const $state = useState(() =>
|
||||
atom<State<T>>({
|
||||
status: 'idle',
|
||||
value: null,
|
||||
error: null,
|
||||
})
|
||||
)[0];
|
||||
|
||||
const trigger = useCallback(async () => {
|
||||
$state.set({
|
||||
status: 'pending',
|
||||
value: null,
|
||||
error: null,
|
||||
});
|
||||
try {
|
||||
const value = await execute();
|
||||
$state.set({
|
||||
status: 'success',
|
||||
value,
|
||||
error: null,
|
||||
});
|
||||
} catch (error) {
|
||||
$state.set({
|
||||
status: 'error',
|
||||
value: null,
|
||||
error: WrappedError.wrap(error),
|
||||
});
|
||||
}
|
||||
}, [$state, execute]);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
$state.set({
|
||||
status: 'idle',
|
||||
value: null,
|
||||
error: null,
|
||||
});
|
||||
}, [$state]);
|
||||
|
||||
useEffect(() => {
|
||||
if (options?.immediate) {
|
||||
trigger();
|
||||
}
|
||||
}, [options?.immediate, trigger]);
|
||||
|
||||
const api = useMemo(
|
||||
() =>
|
||||
({
|
||||
$state,
|
||||
trigger,
|
||||
reset,
|
||||
}) satisfies UseAsyncReturn<T>,
|
||||
[$state, trigger, reset]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
||||
|
||||
type UseAsyncReturnReactive<T> = {
|
||||
state: State<T>;
|
||||
trigger: () => Promise<void>;
|
||||
reset: () => void;
|
||||
};
|
||||
|
||||
export const useAsyncStateReactive = <T>(
|
||||
execute: () => Promise<T>,
|
||||
options?: UseAsyncStateOptions
|
||||
): UseAsyncReturnReactive<T> => {
|
||||
const { $state, trigger, reset } = useAsyncState(execute, options);
|
||||
const state = useStore($state);
|
||||
|
||||
return { state, trigger, reset };
|
||||
};
|
||||
@@ -0,0 +1,29 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import { dropTargetForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import { dropTargetForExternal } from '@atlaskit/pragmatic-drag-and-drop/external/adapter';
|
||||
import { useTimeoutCallback } from 'common/hooks/useTimeoutCallback';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
|
||||
export const useCallbackOnDragEnter = (cb: () => void, ref: RefObject<HTMLElement>, delay = 300) => {
|
||||
const [run, cancel] = useTimeoutCallback(cb, delay);
|
||||
|
||||
useEffect(() => {
|
||||
const element = ref.current;
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
return combine(
|
||||
dropTargetForElements({
|
||||
element,
|
||||
onDragEnter: run,
|
||||
onDragLeave: cancel,
|
||||
}),
|
||||
dropTargetForExternal({
|
||||
element,
|
||||
onDragEnter: run,
|
||||
onDragLeave: cancel,
|
||||
})
|
||||
);
|
||||
}, [cancel, ref, run]);
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user