mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 08:28:14 -05:00
Compare commits
985 Commits
v5.10.0rc1
...
v6.0.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0d67ee6548 | ||
|
|
03c21d1607 | ||
|
|
752e8db1f5 | ||
|
|
85fc861dd9 | ||
|
|
458cbfd874 | ||
|
|
04331c070a | ||
|
|
632ddf0cb4 | ||
|
|
2b193ff416 | ||
|
|
96ee394f9e | ||
|
|
0badc80c0c | ||
|
|
78e6cbf96e | ||
|
|
0b969a661b | ||
|
|
6fe47ec9f8 | ||
|
|
3850dd61f8 | ||
|
|
75520eaf0f | ||
|
|
10e88c58c1 | ||
|
|
30ed4dbd92 | ||
|
|
ed9c090f33 | ||
|
|
d29f65ed22 | ||
|
|
2062ec8ac0 | ||
|
|
49e818338a | ||
|
|
1caab2b9c4 | ||
|
|
50079ea349 | ||
|
|
fffa1b24c4 | ||
|
|
a6d6170387 | ||
|
|
e5fceb0448 | ||
|
|
059baf5b29 | ||
|
|
1be8a9a310 | ||
|
|
7adc33e04d | ||
|
|
7f2dd22d47 | ||
|
|
bb50f4b8a2 | ||
|
|
a48958e0d4 | ||
|
|
e3a1e9af53 | ||
|
|
c6fe11c42f | ||
|
|
4eb1bd67df | ||
|
|
c376f914d2 | ||
|
|
b5d1c47ef7 | ||
|
|
004a52ca65 | ||
|
|
b1d5a51ddf | ||
|
|
2b2498eaa1 | ||
|
|
10dda4440e | ||
|
|
98f78abefa | ||
|
|
cc93fa270f | ||
|
|
014b27680f | ||
|
|
c3d8f875de | ||
|
|
79f9dc6e4a | ||
|
|
6e1c0c1105 | ||
|
|
0362524040 | ||
|
|
dc6656459b | ||
|
|
3ea1b97f6f | ||
|
|
a7c7405ccc | ||
|
|
c391f1117a | ||
|
|
b1e2cb8401 | ||
|
|
db6af134b7 | ||
|
|
7e6cffb00c | ||
|
|
5b187bcb00 | ||
|
|
0843d609a3 | ||
|
|
95bd9cef18 | ||
|
|
931d6521f6 | ||
|
|
e37665ff59 | ||
|
|
56857fbbe6 | ||
|
|
43cfb8a574 | ||
|
|
05b1682d15 | ||
|
|
69a08ee7f2 | ||
|
|
18212c7d8a | ||
|
|
7de26f8e69 | ||
|
|
0652b12a6f | ||
|
|
43a361a00f | ||
|
|
cf68ad9cbc | ||
|
|
ec02a39325 | ||
|
|
e52d7a05c2 | ||
|
|
c9d4e2b761 | ||
|
|
ac26aa9508 | ||
|
|
9ff6ada15b | ||
|
|
e81a115169 | ||
|
|
52827807de | ||
|
|
b631de4cb5 | ||
|
|
099ebdbc37 | ||
|
|
4de6549be9 | ||
|
|
368be34949 | ||
|
|
5baa4bd916 | ||
|
|
4229377532 | ||
|
|
2610772ffd | ||
|
|
193de6a8f2 | ||
|
|
7ea343c787 | ||
|
|
12179dabba | ||
|
|
ef135f9923 | ||
|
|
e6c67cc00f | ||
|
|
179b988148 | ||
|
|
d913a3c85b | ||
|
|
e79525c40c | ||
|
|
f409f913ac | ||
|
|
7a79f61d4c | ||
|
|
ea182c234b | ||
|
|
f2eee4a82d | ||
|
|
e129525306 | ||
|
|
ecedfce758 | ||
|
|
702cb2cb1e | ||
|
|
2e8db3cce3 | ||
|
|
7845623fa5 | ||
|
|
e6a25ca7a2 | ||
|
|
71e12bcebe | ||
|
|
863c7eb9e2 | ||
|
|
9945c20d02 | ||
|
|
e3c1334b1f | ||
|
|
c143f63ef0 | ||
|
|
067026a0d0 | ||
|
|
66991334fc | ||
|
|
b771c3b164 | ||
|
|
4925694dc1 | ||
|
|
0a737ced44 | ||
|
|
8d83caaae0 | ||
|
|
16c8017f1a | ||
|
|
61a35f1396 | ||
|
|
6bd004d868 | ||
|
|
b6a6d406c7 | ||
|
|
8e287c32ee | ||
|
|
2d8b5e26c2 | ||
|
|
50914b74ee | ||
|
|
0fc1c33536 | ||
|
|
3b08c35f72 | ||
|
|
607b2561fd | ||
|
|
d68f922efb | ||
|
|
2bbd74d418 | ||
|
|
3a5392a9ee | ||
|
|
6f80efe71d | ||
|
|
7fac833813 | ||
|
|
b67eb4134d | ||
|
|
522eeda2e2 | ||
|
|
76233241f0 | ||
|
|
54be9989c5 | ||
|
|
0d3af08d27 | ||
|
|
767ac91f2c | ||
|
|
68571ece8f | ||
|
|
01100a2b9a | ||
|
|
ce2e6d8ab6 | ||
|
|
4887424ca3 | ||
|
|
28f6a20e71 | ||
|
|
c4142e75b2 | ||
|
|
fefe563127 | ||
|
|
1c72f1ff9f | ||
|
|
605cc7369d | ||
|
|
e7ce08cffa | ||
|
|
983cb5ebd2 | ||
|
|
52dbdb7118 | ||
|
|
71e6f00e10 | ||
|
|
e73150c3e6 | ||
|
|
f2426c3ab2 | ||
|
|
9d9c4c0f1a | ||
|
|
acb930f6b9 | ||
|
|
585b54dc7d | ||
|
|
f65affc0ec | ||
|
|
22d574c92a | ||
|
|
f23be119fc | ||
|
|
2d06949e80 | ||
|
|
67804313e1 | ||
|
|
dc23be117a | ||
|
|
350de058fc | ||
|
|
fd5cd707a3 | ||
|
|
98ecefdce0 | ||
|
|
42688a0993 | ||
|
|
d94aa4abf7 | ||
|
|
69a56aafed | ||
|
|
56873f6936 | ||
|
|
6bc6a680cf | ||
|
|
9a49682f60 | ||
|
|
ff84b0a495 | ||
|
|
bcced8a5e8 | ||
|
|
4a18e9eaea | ||
|
|
dde5bf61be | ||
|
|
987e401709 | ||
|
|
5c5ac570e3 | ||
|
|
309903fe0f | ||
|
|
f16ea43e9a | ||
|
|
d794aedb43 | ||
|
|
9930440f33 | ||
|
|
f0a6c4aa1f | ||
|
|
f36d22f13c | ||
|
|
e0d7fab524 | ||
|
|
f20c230f4a | ||
|
|
05c9bc730e | ||
|
|
f17ac06591 | ||
|
|
b35f93d919 | ||
|
|
289d8076d8 | ||
|
|
604763d20f | ||
|
|
7b452f098d | ||
|
|
b41c18d35f | ||
|
|
8328081333 | ||
|
|
07517cf2c2 | ||
|
|
6b98ad9095 | ||
|
|
0de3967e7e | ||
|
|
1335377fb1 | ||
|
|
adbcc191d9 | ||
|
|
11fc7af1c8 | ||
|
|
6f12fd22b9 | ||
|
|
324b6e2af4 | ||
|
|
038010a1ca | ||
|
|
2dd1bc54c9 | ||
|
|
8b69842678 | ||
|
|
9821f7c4fc | ||
|
|
2290ff4ad6 | ||
|
|
8d82ad6d0b | ||
|
|
8ed9f652e8 | ||
|
|
ee8ed344bd | ||
|
|
6d16cfdbe2 | ||
|
|
3ef2872dda | ||
|
|
b52ba149b4 | ||
|
|
c6126c6875 | ||
|
|
3f78ac9295 | ||
|
|
79fea1ac40 | ||
|
|
6eade5781d | ||
|
|
3d8f865fb0 | ||
|
|
dc9cd22d9d | ||
|
|
fe115ff8f9 | ||
|
|
1d35aad213 | ||
|
|
195d6ce893 | ||
|
|
f13ced7ed4 | ||
|
|
735fc276e5 | ||
|
|
cd3caf8c30 | ||
|
|
e9012280ab | ||
|
|
fa72a97794 | ||
|
|
e817631ba3 | ||
|
|
d0619c033f | ||
|
|
6f4850f34f | ||
|
|
072cd9dee7 | ||
|
|
19b6dc1c1f | ||
|
|
7566d0d6c6 | ||
|
|
f123888b46 | ||
|
|
aeab7d0cab | ||
|
|
3f1b2c39ab | ||
|
|
72e3a4b4be | ||
|
|
58e0f80138 | ||
|
|
8b8e29d22d | ||
|
|
90201be670 | ||
|
|
46a5619100 | ||
|
|
d608a7469e | ||
|
|
a7d413d372 | ||
|
|
f5c9e68dbf | ||
|
|
1ded459f03 | ||
|
|
d9024dc230 | ||
|
|
40528692c3 | ||
|
|
f35b05be43 | ||
|
|
29e87fc615 | ||
|
|
ca26b2718e | ||
|
|
5fa6c0b413 | ||
|
|
c37c8c50cd | ||
|
|
f0a4de245d | ||
|
|
5db62f8643 | ||
|
|
e1c478f94c | ||
|
|
11fe3b6332 | ||
|
|
e4aae1a591 | ||
|
|
4d83d1c56d | ||
|
|
34def323e8 | ||
|
|
854956316b | ||
|
|
91afe7884a | ||
|
|
8417ee8a7b | ||
|
|
a035645ed3 | ||
|
|
e00ccba7d3 | ||
|
|
fb883d63aa | ||
|
|
b113c57fc4 | ||
|
|
7636007349 | ||
|
|
fda86ae981 | ||
|
|
c02be4bdf4 | ||
|
|
ed7772d993 | ||
|
|
baae998b5b | ||
|
|
4077ffe595 | ||
|
|
c1937b1379 | ||
|
|
5c66dfed8e | ||
|
|
126dcc96c0 | ||
|
|
cb9c7b4a28 | ||
|
|
e8c4f49a14 | ||
|
|
30fffae637 | ||
|
|
4558a292b6 | ||
|
|
825d17441c | ||
|
|
9b16504af9 | ||
|
|
46c92fadff | ||
|
|
c0467b82ac | ||
|
|
6dafa67286 | ||
|
|
eb406aa07e | ||
|
|
d9422ffebd | ||
|
|
d5c033be4d | ||
|
|
4662cd6f15 | ||
|
|
a740a22613 | ||
|
|
bf4016b4bc | ||
|
|
6fa7c8c2ee | ||
|
|
ea40f582da | ||
|
|
01caf56251 | ||
|
|
42d577e65a | ||
|
|
38d80c9ce5 | ||
|
|
6acaa8abbf | ||
|
|
4b84e34599 | ||
|
|
bbd21b1eb2 | ||
|
|
4fa83a6228 | ||
|
|
051876dcff | ||
|
|
8dc6d0b5ae | ||
|
|
40e9624954 | ||
|
|
ae27c83dc4 | ||
|
|
161059551b | ||
|
|
c196f8a5d5 | ||
|
|
2c6d22664e | ||
|
|
b9ce5389ef | ||
|
|
d1cbf56695 | ||
|
|
e379ac12c3 | ||
|
|
aa10373292 | ||
|
|
780f3692a0 | ||
|
|
3604dcfdd1 | ||
|
|
2b1cffde5e | ||
|
|
83d642ed15 | ||
|
|
455c73235e | ||
|
|
8efef8da41 | ||
|
|
060a9e57b9 | ||
|
|
099d75ca1e | ||
|
|
bbb5d68146 | ||
|
|
9066dc1839 | ||
|
|
075345bffd | ||
|
|
74d1239c87 | ||
|
|
51e1c56636 | ||
|
|
ca1df60e54 | ||
|
|
7549c1250d | ||
|
|
df8751b5a1 | ||
|
|
651b80b997 | ||
|
|
5d236ae4e7 | ||
|
|
e5dc606f5e | ||
|
|
dc6b8e13bd | ||
|
|
c1b34e1f11 | ||
|
|
89f1684072 | ||
|
|
14fbee17a3 | ||
|
|
5dbc32e06e | ||
|
|
23baf61e51 | ||
|
|
5e55f6074b | ||
|
|
f7c555e501 | ||
|
|
6aa605e811 | ||
|
|
f51014e108 | ||
|
|
9862ba9210 | ||
|
|
920aea08cc | ||
|
|
39e584297e | ||
|
|
62a14bb935 | ||
|
|
d7ae2cdf75 | ||
|
|
6172c859ac | ||
|
|
b26fb1f617 | ||
|
|
05167dfd7a | ||
|
|
c090ea7387 | ||
|
|
7ba6c67049 | ||
|
|
3de186061d | ||
|
|
a716381733 | ||
|
|
fb5df06835 | ||
|
|
33c597c224 | ||
|
|
19d882d038 | ||
|
|
ee4bc49bd4 | ||
|
|
188cf37f48 | ||
|
|
15a0a7134c | ||
|
|
22cea0de8b | ||
|
|
cd21816d12 | ||
|
|
605b912ba4 | ||
|
|
52e31112f9 | ||
|
|
a4c9346cd7 | ||
|
|
a1647e4c6e | ||
|
|
8c9ca088a7 | ||
|
|
7a7a2e147c | ||
|
|
adf4cc750a | ||
|
|
9f1ea9d1c7 | ||
|
|
571d286506 | ||
|
|
1320a2c5f8 | ||
|
|
26a9b3131d | ||
|
|
d48140b35d | ||
|
|
9757bb0325 | ||
|
|
38ccd8e09c | ||
|
|
7759b166a9 | ||
|
|
9fc51c7a6e | ||
|
|
62fa4f42f5 | ||
|
|
418ad0de38 | ||
|
|
f4a411326e | ||
|
|
6358f39ebb | ||
|
|
ea8da0bfbf | ||
|
|
5385282325 | ||
|
|
0bf84ab803 | ||
|
|
82f31f2258 | ||
|
|
966dd8857d | ||
|
|
1c778bd719 | ||
|
|
394a14cf61 | ||
|
|
0e843823d1 | ||
|
|
29462e62d2 | ||
|
|
175c0147f8 | ||
|
|
df6e67c982 | ||
|
|
4612f0ac50 | ||
|
|
386a932f2a | ||
|
|
32438532b0 | ||
|
|
ab5cb2c264 | ||
|
|
504daa0ae5 | ||
|
|
14f7c98e8a | ||
|
|
ab39305223 | ||
|
|
7948bca864 | ||
|
|
1a39d22b6c | ||
|
|
9424271d12 | ||
|
|
b5acc204a8 | ||
|
|
7aefa8f36b | ||
|
|
242da9e888 | ||
|
|
1aedc26041 | ||
|
|
2c7fa90892 | ||
|
|
6c8cf99ad2 | ||
|
|
a92ba2542c | ||
|
|
2367b9f945 | ||
|
|
a928ed0204 | ||
|
|
e164451dfe | ||
|
|
d74d079356 | ||
|
|
0eb4360c01 | ||
|
|
937c03f2ec | ||
|
|
f7b249252d | ||
|
|
b2b42be51c | ||
|
|
98368b0665 | ||
|
|
b5eb3d9798 | ||
|
|
1218f49e20 | ||
|
|
89c609fd61 | ||
|
|
b204fb6a91 | ||
|
|
6e3e316416 | ||
|
|
bf5fc9512d | ||
|
|
7080889ed4 | ||
|
|
adea983bfc | ||
|
|
f68d8ed36a | ||
|
|
d45197e0af | ||
|
|
434d8a2b12 | ||
|
|
f55c593705 | ||
|
|
8327d86774 | ||
|
|
c8254710e6 | ||
|
|
0a8f647260 | ||
|
|
32a5e9652a | ||
|
|
87909a06a8 | ||
|
|
2c8ce6f2f4 | ||
|
|
bee4cf41b4 | ||
|
|
049a8d8144 | ||
|
|
ac81ec41c3 | ||
|
|
a294e8e0fd | ||
|
|
4665f0df40 | ||
|
|
70382294f5 | ||
|
|
4028cadfaf | ||
|
|
d23cdfd0ad | ||
|
|
f0ba693922 | ||
|
|
214005d795 | ||
|
|
34aa131115 | ||
|
|
5d8061bea9 | ||
|
|
36ec1015d6 | ||
|
|
7208373576 | ||
|
|
e10afe3026 | ||
|
|
399d6e7bce | ||
|
|
8d0fe5522b | ||
|
|
81341deb46 | ||
|
|
a30933b09c | ||
|
|
3264188ffd | ||
|
|
3984b341e1 | ||
|
|
041023df53 | ||
|
|
b06f76cdb6 | ||
|
|
852badc90b | ||
|
|
01953cf057 | ||
|
|
241844bdef | ||
|
|
33a28ad4f9 | ||
|
|
7c4550cbd5 | ||
|
|
553d1a6ac6 | ||
|
|
f4794e409b | ||
|
|
df87800d61 | ||
|
|
16993cd216 | ||
|
|
7f222ffb9d | ||
|
|
e0ed56ff8d | ||
|
|
e7e1142c77 | ||
|
|
fcaeba290e | ||
|
|
6eecdca56c | ||
|
|
7f44da4902 | ||
|
|
abaa33e22c | ||
|
|
d5c238e7c2 | ||
|
|
18775e8b67 | ||
|
|
903776bfbc | ||
|
|
a5baf0c102 | ||
|
|
a7e45731ec | ||
|
|
32aa3e6d48 | ||
|
|
2f9ea91896 | ||
|
|
5ac5115269 | ||
|
|
161624c722 | ||
|
|
c31cb0b106 | ||
|
|
893f7a8744 | ||
|
|
2e0824a799 | ||
|
|
ed05bf2df3 | ||
|
|
0f1a69a0c3 | ||
|
|
450a0bf142 | ||
|
|
a28c15d545 | ||
|
|
1b1e1983d9 | ||
|
|
d08e2fbd82 | ||
|
|
45b1ef6231 | ||
|
|
3bb446c08f | ||
|
|
8d1ab0a2e5 | ||
|
|
48e2e7e4a1 | ||
|
|
5a2f5c105d | ||
|
|
aa93e95a94 | ||
|
|
a5e5cbd7c3 | ||
|
|
baa9141be3 | ||
|
|
c7ed351bab | ||
|
|
8c17bde4ea | ||
|
|
ba082ccc2f | ||
|
|
01784fb3bf | ||
|
|
a71a0e143c | ||
|
|
94afc13813 | ||
|
|
d640a9001b | ||
|
|
711fe91b24 | ||
|
|
2f26657c17 | ||
|
|
6754fde935 | ||
|
|
ac206f4767 | ||
|
|
c316f07fb2 | ||
|
|
e81dde0933 | ||
|
|
9f392c8c3c | ||
|
|
2531366386 | ||
|
|
9df69496e4 | ||
|
|
2ddcde13ff | ||
|
|
cc5083599d | ||
|
|
2431060a7e | ||
|
|
592c842632 | ||
|
|
bc3550f238 | ||
|
|
23511d68db | ||
|
|
cd0668dd0b | ||
|
|
bf5ed61b84 | ||
|
|
3038a797a6 | ||
|
|
9bbc31b2d9 | ||
|
|
526e6335a1 | ||
|
|
1412c079ad | ||
|
|
6570c0c3b9 | ||
|
|
3a08ea799a | ||
|
|
e3fc244126 | ||
|
|
56938ca0a1 | ||
|
|
5d80642ea4 | ||
|
|
da4b084a8b | ||
|
|
86e1a37a00 | ||
|
|
ea34690709 | ||
|
|
c8df7cd2c0 | ||
|
|
628367b97b | ||
|
|
002816653e | ||
|
|
b05de8634d | ||
|
|
5088e700ad | ||
|
|
d2155e98ef | ||
|
|
7ec511da01 | ||
|
|
985cd8272b | ||
|
|
cd136194ad | ||
|
|
2e2ac71278 | ||
|
|
db4220fb20 | ||
|
|
84f70942e7 | ||
|
|
0af20b03e5 | ||
|
|
e16414b452 | ||
|
|
5dbc2a74a2 | ||
|
|
ad736bc190 | ||
|
|
0e9b71801a | ||
|
|
e80f0b2b43 | ||
|
|
c9042e52d4 | ||
|
|
8a78e37634 | ||
|
|
5e93f58530 | ||
|
|
a3851e0b08 | ||
|
|
eb45a457e9 | ||
|
|
1446d3490b | ||
|
|
579318af70 | ||
|
|
57bfae6774 | ||
|
|
2a92524546 | ||
|
|
7a5fa25b48 | ||
|
|
b3f3020793 | ||
|
|
650809e50d | ||
|
|
7308428f32 | ||
|
|
4dc3f1bcee | ||
|
|
faeb5f0c3b | ||
|
|
d985dfe821 | ||
|
|
ce5ae83689 | ||
|
|
c0428ee7ef | ||
|
|
aa3b2106d4 | ||
|
|
cf2d67ef3d | ||
|
|
c4d1e78f59 | ||
|
|
02e4a3aa82 | ||
|
|
a0b0c30be9 | ||
|
|
5c4cbc7fa2 | ||
|
|
5f2f12f803 | ||
|
|
c9cd0a87be | ||
|
|
668c475271 | ||
|
|
341910739e | ||
|
|
53a3dc52bc | ||
|
|
23b0a4a7f4 | ||
|
|
6afbf31750 | ||
|
|
3cd4306eec | ||
|
|
827191d2fc | ||
|
|
aaa34f717d | ||
|
|
fe83c2f81f | ||
|
|
17dead3309 | ||
|
|
979bd33dfb | ||
|
|
5128f072a8 | ||
|
|
2ad5b5cc2e | ||
|
|
24d8a96071 | ||
|
|
f1e4665aa2 | ||
|
|
1cbfea3a21 | ||
|
|
981e8e217d | ||
|
|
e7ca30f406 | ||
|
|
2832ca300f | ||
|
|
de5f413440 | ||
|
|
fbc14c61ea | ||
|
|
77e029a49f | ||
|
|
61b049ad35 | ||
|
|
b88f4a24d0 | ||
|
|
8c632f0d32 | ||
|
|
150a876c73 | ||
|
|
62c3b01e4f | ||
|
|
e1157f343b | ||
|
|
6a78739076 | ||
|
|
0794eb43e7 | ||
|
|
4ee54eac1d | ||
|
|
5851c46c81 | ||
|
|
a296559e79 | ||
|
|
1fd83f5e68 | ||
|
|
637487c573 | ||
|
|
4e98e7d0a2 | ||
|
|
12f65d800d | ||
|
|
45d09f8f51 | ||
|
|
2876c72fa9 | ||
|
|
9b4fdb493e | ||
|
|
47e21d6e04 | ||
|
|
84ab4a1c30 | ||
|
|
85c4304efd | ||
|
|
8f152f162b | ||
|
|
63b49f045a | ||
|
|
291e0736d6 | ||
|
|
4bfa6439d4 | ||
|
|
a8d7969a1d | ||
|
|
46bfa24af3 | ||
|
|
a8cb8e128d | ||
|
|
8cef0f5bf5 | ||
|
|
911baeb58b | ||
|
|
312960645b | ||
|
|
50cf285efb | ||
|
|
a214f4fff5 | ||
|
|
2981591c36 | ||
|
|
b08f90c99f | ||
|
|
ab8c739cd8 | ||
|
|
5c5108c28a | ||
|
|
3df7cfd605 | ||
|
|
1ff3d44dba | ||
|
|
c80ad90f72 | ||
|
|
3b4d1b8786 | ||
|
|
c66201c7e1 | ||
|
|
35c7c59455 | ||
|
|
85f98ab3eb | ||
|
|
dac75685be | ||
|
|
d7b5a8b298 | ||
|
|
d3ecaa740f | ||
|
|
b5a6765a3d | ||
|
|
3704573ef8 | ||
|
|
01fbf2ce4d | ||
|
|
96e7003449 | ||
|
|
80197b8856 | ||
|
|
0187bc671e | ||
|
|
31584daabe | ||
|
|
a6cb522fed | ||
|
|
f70be1e415 | ||
|
|
a2901f2b46 | ||
|
|
b61c66c3a9 | ||
|
|
c77f9ec202 | ||
|
|
2c5c35647f | ||
|
|
bf0fdbd10e | ||
|
|
731d317a42 | ||
|
|
e81579f752 | ||
|
|
9a10e98c0b | ||
|
|
27fdc139b7 | ||
|
|
0a00805afc | ||
|
|
7b38143fbd | ||
|
|
4c5ad1b7d7 | ||
|
|
d80cc962ad | ||
|
|
7ccabfa200 | ||
|
|
936d59cc52 | ||
|
|
fc16fb6099 | ||
|
|
c848cbc2e3 | ||
|
|
66fd0f0d8a | ||
|
|
c266f39f06 | ||
|
|
98a44fa4d7 | ||
|
|
c1d230f961 | ||
|
|
68108435ae | ||
|
|
e121bf1f62 | ||
|
|
4835c344b3 | ||
|
|
a589dec122 | ||
|
|
bc67d5c841 | ||
|
|
f3d5691c04 | ||
|
|
b98abc2457 | ||
|
|
7e527ccfb7 | ||
|
|
0f0c911845 | ||
|
|
e4818b967b | ||
|
|
ce3eede26f | ||
|
|
d98725c5e9 | ||
|
|
31a96d2945 | ||
|
|
845a321a43 | ||
|
|
87a44a28ef | ||
|
|
d5b9c3ee5a | ||
|
|
91db136cd1 | ||
|
|
f351ad4b66 | ||
|
|
fb6fb9abbd | ||
|
|
675c990486 | ||
|
|
6ee5cde4bb | ||
|
|
c8077f9430 | ||
|
|
6aabe9959e | ||
|
|
0b58d172d2 | ||
|
|
d7c6e293d7 | ||
|
|
c600bc867d | ||
|
|
f4140dd772 | ||
|
|
a2d8261d40 | ||
|
|
bce88a8873 | ||
|
|
b37e1a3ad6 | ||
|
|
35a088e0a6 | ||
|
|
b936cab039 | ||
|
|
34e4093408 | ||
|
|
d7f93c3cc0 | ||
|
|
d4c4926caa | ||
|
|
558c7db055 | ||
|
|
2ece59b51b | ||
|
|
7dbe39957c | ||
|
|
6fa46d35a5 | ||
|
|
b2a2b38ea8 | ||
|
|
12934da390 | ||
|
|
231bc18188 | ||
|
|
530cd180c5 | ||
|
|
2a92e7b920 | ||
|
|
019e057e29 | ||
|
|
9aa26f883e | ||
|
|
3f727e24b1 | ||
|
|
9e90bf1b20 | ||
|
|
db3964797f | ||
|
|
881efbda1b | ||
|
|
e9ce2ed5f2 | ||
|
|
53ac9eafbf | ||
|
|
9e095006a5 | ||
|
|
21b24c3ba6 | ||
|
|
139ecc10ce | ||
|
|
78ea143b46 | ||
|
|
174249ec15 | ||
|
|
2510ad7431 | ||
|
|
ba5e855a60 | ||
|
|
23627cf18d | ||
|
|
5e20c9a1ca | ||
|
|
933cf5f276 | ||
|
|
41316de659 | ||
|
|
041ccfd68e | ||
|
|
ad24c203a4 | ||
|
|
3fd28ce600 | ||
|
|
32df3bdf6e | ||
|
|
ba69e89e8c | ||
|
|
a8e0c48ddc | ||
|
|
66f6571086 | ||
|
|
8a3848e7b6 | ||
|
|
3f8486b480 | ||
|
|
b80be4f639 | ||
|
|
adb3a849b9 | ||
|
|
798499fda6 | ||
|
|
02fc5a165c | ||
|
|
b1b8edecfb | ||
|
|
3cd8d48809 | ||
|
|
f4672ad8c1 | ||
|
|
5a86490845 | ||
|
|
27dc843046 | ||
|
|
2f35d74902 | ||
|
|
8bd52ed744 | ||
|
|
f3e2a3c384 | ||
|
|
ecc6e8a532 | ||
|
|
9170576a38 | ||
|
|
f26baa0341 | ||
|
|
99dad953a4 | ||
|
|
c39bcdffd3 | ||
|
|
32f2223237 | ||
|
|
6176941853 | ||
|
|
af41dc83f7 | ||
|
|
a17e771eba | ||
|
|
19ecdb196e | ||
|
|
15880e6ea7 | ||
|
|
53ffa98662 | ||
|
|
021a334240 | ||
|
|
cfed293d48 | ||
|
|
d36bc185c8 | ||
|
|
7878203b03 | ||
|
|
3352220d39 | ||
|
|
bcfb1e7e52 | ||
|
|
e84b3c142c | ||
|
|
22f637b647 | ||
|
|
5d192ab6e5 | ||
|
|
9273d1629e | ||
|
|
27a12f080b | ||
|
|
3bfb497764 | ||
|
|
b849c7d382 | ||
|
|
8d4120583d | ||
|
|
402cdc7eda | ||
|
|
b02ea1a898 | ||
|
|
d709040f4b | ||
|
|
8a7a498da3 | ||
|
|
699736486b | ||
|
|
37e790ae19 | ||
|
|
6c0bd7d150 | ||
|
|
99e154d773 | ||
|
|
e4e43ae126 | ||
|
|
a07fac6180 | ||
|
|
93d4b00082 | ||
|
|
8abcc99ced | ||
|
|
73ab4b8895 | ||
|
|
86719f2065 | ||
|
|
5271fc1cac | ||
|
|
96ff7d9093 | ||
|
|
6f73d9e9c6 | ||
|
|
29b406a84b | ||
|
|
2b1e4b88d3 | ||
|
|
0f0085a776 | ||
|
|
ea28ed8261 | ||
|
|
c0e6327d3a | ||
|
|
459491e402 | ||
|
|
a4cddfa47d | ||
|
|
9a822bcfe8 | ||
|
|
5f12b9185f | ||
|
|
d958d2e5a0 | ||
|
|
823ca214e6 | ||
|
|
a33da450fd | ||
|
|
8b5f4d190c | ||
|
|
f1f3b7965a | ||
|
|
987be3507c | ||
|
|
1f4090fe0e | ||
|
|
029e2d2c46 | ||
|
|
7722f479e8 | ||
|
|
3ad4072183 | ||
|
|
6dfb9a1906 | ||
|
|
ad2924350d | ||
|
|
3bf51ee0c2 | ||
|
|
fce5051dcc | ||
|
|
446d8818b9 | ||
|
|
1566e29c19 | ||
|
|
6a2e35f2c4 | ||
|
|
b6d58774f4 | ||
|
|
758f94d3c6 | ||
|
|
9df0871754 | ||
|
|
3011150a3a | ||
|
|
05aa1fce71 | ||
|
|
df81f3274a | ||
|
|
143487a492 | ||
|
|
203fa04295 | ||
|
|
954fce3c67 | ||
|
|
821889148a | ||
|
|
4c248d8c2c | ||
|
|
deb75805d4 | ||
|
|
93110654da | ||
|
|
ff0c48d532 | ||
|
|
de18073814 | ||
|
|
0708af9545 | ||
|
|
1e85184c62 | ||
|
|
11d3b8d944 | ||
|
|
bffd4afb96 | ||
|
|
518a896521 | ||
|
|
2647ff141a | ||
|
|
ba0bac2aa5 | ||
|
|
862e2a3e49 | ||
|
|
d22fd32b05 | ||
|
|
391e5b7f8c | ||
|
|
c9d2a5f59a | ||
|
|
1f63b60021 | ||
|
|
a499b9f54e | ||
|
|
104505ea02 | ||
|
|
ee4002607c | ||
|
|
fd20582cdd | ||
|
|
43b0d07517 | ||
|
|
f83592a052 | ||
|
|
b3ee906749 | ||
|
|
5d69e9068a | ||
|
|
a79136b058 | ||
|
|
944af4d4a9 | ||
|
|
5e001be73a | ||
|
|
576a644b3a | ||
|
|
703557c8a6 | ||
|
|
d59a53b3f9 | ||
|
|
7b8f78c2d9 | ||
|
|
31ab9be79a | ||
|
|
5011fab85d | ||
|
|
92bdb9fdcc | ||
|
|
548e766c0b | ||
|
|
ff897f74a1 | ||
|
|
3d29c996ed | ||
|
|
42d57d1225 | ||
|
|
193fa9395a | ||
|
|
56cd839d5b | ||
|
|
7b446ee40d | ||
|
|
17027c4070 | ||
|
|
13d44f47ce | ||
|
|
550fbdeb1c | ||
|
|
a01cd7c497 | ||
|
|
c54afd600c | ||
|
|
4f911a0ea8 | ||
|
|
fb91f48722 | ||
|
|
69db60a614 | ||
|
|
c6d7f951aa | ||
|
|
04c005284c | ||
|
|
2d7f9697bf | ||
|
|
ae530492a2 | ||
|
|
87ed1e3b6d | ||
|
|
cc54466db9 | ||
|
|
cbdafe7e38 | ||
|
|
112cb76174 | ||
|
|
e56d41ab99 | ||
|
|
273dfd86ab | ||
|
|
871271fde5 | ||
|
|
14944872c4 | ||
|
|
07bcf3c446 | ||
|
|
8ed5585285 | ||
|
|
5ce226a467 | ||
|
|
c64f20a72b | ||
|
|
0c9c10a03a | ||
|
|
4a0df6b865 | ||
|
|
ba165572bf | ||
|
|
c3d6a10603 | ||
|
|
4efc86299d | ||
|
|
e8c7cf63fd | ||
|
|
698b034190 | ||
|
|
3988128c40 | ||
|
|
c768f47365 | ||
|
|
19a63abc54 | ||
|
|
75ec36bf9a | ||
|
|
d802f8e7fb | ||
|
|
6873e0308d | ||
|
|
66eb73088e | ||
|
|
ed81a13eb4 | ||
|
|
fbc1aae52d | ||
|
|
ba42c3e63f | ||
|
|
b24e820aa0 | ||
|
|
e8f6b3b77a | ||
|
|
8f13518c97 | ||
|
|
6afbc12074 | ||
|
|
6b0a56ceb9 | ||
|
|
ca92497e52 | ||
|
|
97d45ceaf2 | ||
|
|
aeb3841a6f | ||
|
|
c14d33d3c1 | ||
|
|
676e59e072 | ||
|
|
e7dcb6a03f | ||
|
|
fb95b7cc2b | ||
|
|
015dc3ac0d | ||
|
|
9d8a71b362 | ||
|
|
2eb212f393 | ||
|
|
34b268c15c | ||
|
|
9a203a64dc | ||
|
|
d80004e056 | ||
|
|
de32ed23a7 | ||
|
|
5aed2b315d | ||
|
|
48db6cfc4f | ||
|
|
aa7c5c281a | ||
|
|
87aeb7f889 | ||
|
|
3b3d6e413a | ||
|
|
b6432f2de3 | ||
|
|
9d0a28ccae | ||
|
|
c3bf0a3277 | ||
|
|
b516610c1e | ||
|
|
677e717cd7 | ||
|
|
c52584e057 | ||
|
|
b6767441db | ||
|
|
8745dbe67d | ||
|
|
a565d9473e | ||
|
|
4dbf07c3e0 | ||
|
|
f6eb4d9a6b | ||
|
|
5037967b82 | ||
|
|
4930ba48ce | ||
|
|
40d2092256 | ||
|
|
d2e9237740 | ||
|
|
b191b706c1 | ||
|
|
4d0f760ec8 | ||
|
|
65cda5365a | ||
|
|
1f2d1d086f | ||
|
|
418f3c3f19 | ||
|
|
72173e284c | ||
|
|
9cc13556aa | ||
|
|
298444f2bc | ||
|
|
deb1984289 | ||
|
|
814406d98a | ||
|
|
c054501103 | ||
|
|
c1d819c7e5 | ||
|
|
2a8e91f94d | ||
|
|
64f3e56039 | ||
|
|
819afab230 | ||
|
|
9fff064c55 | ||
|
|
1aa8d94378 | ||
|
|
d78bdde2c3 | ||
|
|
7b663b3432 | ||
|
|
9c4159915a | ||
|
|
dbb5830027 | ||
|
|
4fc4dbb656 | ||
|
|
d4f6d09cc9 | ||
|
|
44e44602d3 | ||
|
|
36066c5f26 | ||
|
|
361c6eed4b | ||
|
|
bb154fd40f | ||
|
|
cbee6e6faf |
29
.github/CODEOWNERS
vendored
29
.github/CODEOWNERS
vendored
@@ -1,32 +1,31 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku @psychedelicious
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
|
||||
/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant @hipsterusername
|
||||
/docker/ @lstein @blessedcoolant @hipsterusername @ebr
|
||||
/scripts/ @ebr @lstein @hipsterusername
|
||||
/installer/ @lstein @ebr @hipsterusername
|
||||
/invokeai/assets @lstein @ebr @hipsterusername
|
||||
/invokeai/configs @lstein @hipsterusername
|
||||
/invokeai/version @lstein @blessedcoolant @hipsterusername
|
||||
/pyproject.toml @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/docker/ @lstein @blessedcoolant @psychedelicious @hipsterusername @ebr
|
||||
/scripts/ @ebr @lstein @psychedelicious @hipsterusername
|
||||
/installer/ @lstein @ebr @psychedelicious @hipsterusername
|
||||
/invokeai/assets @lstein @ebr @psychedelicious @hipsterusername
|
||||
/invokeai/configs @lstein @psychedelicious @hipsterusername
|
||||
/invokeai/version @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
|
||||
/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein @hipsterusername
|
||||
/invokeai/frontend/install @lstein @ebr @hipsterusername
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/CLI @lstein @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/install @lstein @ebr @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @psychedelicious @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername
|
||||
|
||||
@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
|
||||
runs:
|
||||
using: 'composite'
|
||||
steps:
|
||||
- name: setup node 18
|
||||
- name: setup node 20
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
node-version: '20'
|
||||
|
||||
- name: setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 8.15.6
|
||||
version: 10
|
||||
run_install: false
|
||||
|
||||
- name: get pnpm store directory
|
||||
|
||||
4
.github/workflows/python-checks.yml
vendored
4
.github/workflows/python-checks.yml
vendored
@@ -67,6 +67,10 @@ jobs:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
|
||||
- name: check pypi classifiers
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: uv run --no-project scripts/check_classifiers.py ./pyproject.toml
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: uv tool run ruff@0.11.2 check --output-format=github .
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -180,6 +180,7 @@ cython_debug/
|
||||
# Scratch folder
|
||||
.scratch/
|
||||
.vscode/
|
||||
.zed/
|
||||
|
||||
# source installer files
|
||||
installer/*zip
|
||||
@@ -188,3 +189,4 @@ installer/install.sh
|
||||
installer/update.bat
|
||||
installer/update.sh
|
||||
installer/InvokeAI-Installer/
|
||||
.aider*
|
||||
|
||||
@@ -39,7 +39,7 @@ nodes imported in the `__init__.py` file are loaded. See the README in the nodes
|
||||
folder for more examples:
|
||||
|
||||
```py
|
||||
from .cool_node import CoolInvocation
|
||||
from .cool_node import ResizeInvocation
|
||||
```
|
||||
|
||||
## Creating A New Invocation
|
||||
@@ -69,7 +69,10 @@ The first set of things we need to do when creating a new Invocation are -
|
||||
So let us do that.
|
||||
|
||||
```python
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
invocation,
|
||||
)
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -103,8 +106,12 @@ create your own custom field types later in this guide. For now, let's go ahead
|
||||
and use it.
|
||||
|
||||
```python
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
invocation,
|
||||
)
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -128,8 +135,12 @@ image: ImageField = InputField(description="The input image")
|
||||
Great. Now let us create our other inputs for `width` and `height`
|
||||
|
||||
```python
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
invocation,
|
||||
)
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -163,8 +174,13 @@ that are provided by it by InvokeAI.
|
||||
Let us create this function first.
|
||||
|
||||
```python
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
|
||||
@invocation('resize')
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
@@ -191,8 +207,14 @@ all the necessary info related to image outputs. So let us use that.
|
||||
We will cover how to create your own output types later in this guide.
|
||||
|
||||
```python
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
|
||||
@invocation('resize')
|
||||
@@ -217,9 +239,15 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
|
||||
So let's do that.
|
||||
|
||||
```python
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
|
||||
@invocation("resize")
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
|
||||
@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
With the modifications made, the install command should look something like this:
|
||||
|
||||
```sh
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu126 --reinstall
|
||||
```
|
||||
|
||||
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
|
||||
@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
|
||||
<!-- links -->
|
||||
|
||||
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
|
||||
[zod]: https://github.com/colinhacks/zod 'zod'
|
||||
[zod]: https://github.com/colinhacks/zod 'zod/v4'
|
||||
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
|
||||
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
|
||||
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
|
||||
|
||||
@@ -71,7 +71,21 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
|
||||
7. Determine the `PyPI` index URL to use for installation, if any. This is necessary to get the right version of torch installed.
|
||||
|
||||
=== "Invoke v5 or later"
|
||||
=== "Invoke v5.12 and later"
|
||||
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu128`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v5.10.0 to v5.11.0"
|
||||
|
||||
- If you are on Windows or Linux with an Nvidia GPU, use `https://download.pytorch.org/whl/cu126`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
- If you are on Linux with an AMD GPU, use `https://download.pytorch.org/whl/rocm6.2.4`.
|
||||
- **In all other cases, do not use an index.**
|
||||
|
||||
=== "Invoke v5.0.0 to v5.9.1"
|
||||
|
||||
- If you are on Windows with an Nvidia GPU, use `https://download.pytorch.org/whl/cu124`.
|
||||
- If you are on Linux with no GPU, use `https://download.pytorch.org/whl/cpu`.
|
||||
|
||||
@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
|
||||
|
||||
## Step 2: Download
|
||||
|
||||
Download the most launcher for your operating system:
|
||||
Download the most recent launcher for your operating system:
|
||||
|
||||
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
|
||||
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)
|
||||
|
||||
@@ -13,6 +13,7 @@ If you'd prefer, you can also just download the whole node folder from the linke
|
||||
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
|
||||
|
||||
- Community Nodes
|
||||
+ [Anamorphic Tools](#anamorphic-tools)
|
||||
+ [Adapters-Linked](#adapters-linked-nodes)
|
||||
+ [Autostereogram](#autostereogram-nodes)
|
||||
+ [Average Images](#average-images)
|
||||
@@ -20,9 +21,12 @@ To use a community workflow, download the `.json` node graph file and load it in
|
||||
+ [Close Color Mask](#close-color-mask)
|
||||
+ [Clothing Mask](#clothing-mask)
|
||||
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
|
||||
+ [Curves](#curves)
|
||||
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
|
||||
+ [Enhance Detail](#enhance-detail)
|
||||
+ [Film Grain](#film-grain)
|
||||
+ [Flip Pose](#flip-pose)
|
||||
+ [Flux Ideal Size](#flux-ideal-size)
|
||||
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
|
||||
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
|
||||
+ [Grid to Gif](#grid-to-gif)
|
||||
@@ -61,6 +65,13 @@ To use a community workflow, download the `.json` node graph file and load it in
|
||||
- [Help](#help)
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Anamorphic Tools
|
||||
|
||||
**Description:** A set of nodes to perform anamorphic modifications to images, like lens blur, streaks, spherical distortion, and vignetting.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/anamorphic-tools
|
||||
|
||||
--------------------------------
|
||||
### Adapters Linked Nodes
|
||||
|
||||
@@ -132,6 +143,13 @@ Node Link: https://github.com/VeyDlin/clahe-node
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Curves
|
||||
|
||||
**Description:** Adjust an image's curve based on a user-defined string.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/curves-node
|
||||
|
||||
--------------------------------
|
||||
### Depth Map from Wavefront OBJ
|
||||
|
||||
@@ -162,6 +180,20 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/film-grain-node
|
||||
|
||||
--------------------------------
|
||||
### Flip Pose
|
||||
|
||||
**Description:** This node will flip an openpose image horizontally, recoloring it to make sure that it isn't facing the wrong direction. Note that it does not work with openpose hands.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/flip-pose-node
|
||||
|
||||
--------------------------------
|
||||
### Flux Ideal Size
|
||||
|
||||
**Description:** This node returns an ideal size to use for the first stage of a Flux image generation pipeline. Generating at the right size helps limit duplication and odd subject placement.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/flux-ideal-size
|
||||
|
||||
--------------------------------
|
||||
### Generative Grammar-Based Prompt Nodes
|
||||
|
||||
|
||||
@@ -23,6 +23,10 @@ from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
|
||||
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
|
||||
SqliteModelRelationshipRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
|
||||
from invokeai.app.services.names.names_default import SimpleNameService
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
@@ -113,7 +117,6 @@ class ApiDependencies:
|
||||
safe_globals=[torch.Tensor],
|
||||
ephemeral=True,
|
||||
),
|
||||
max_cache_size=0,
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](
|
||||
@@ -137,6 +140,8 @@ class ApiDependencies:
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
model_relationships = ModelRelationshipsService()
|
||||
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||
@@ -162,6 +167,8 @@ class ApiDependencies:
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
model_relationships=model_relationships,
|
||||
model_relationship_records=model_relationship_records,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from importlib.metadata import distributions
|
||||
from pathlib import Path
|
||||
from platform import python_version
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -44,24 +43,6 @@ class AppVersion(BaseModel):
|
||||
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
|
||||
|
||||
|
||||
class AppDependencyVersions(BaseModel):
|
||||
"""App depencency Versions Response"""
|
||||
|
||||
accelerate: str = Field(description="accelerate version")
|
||||
compel: str = Field(description="compel version")
|
||||
cuda: Optional[str] = Field(description="CUDA version")
|
||||
diffusers: str = Field(description="diffusers version")
|
||||
numpy: str = Field(description="Numpy version")
|
||||
opencv: str = Field(description="OpenCV version")
|
||||
onnx: str = Field(description="ONNX version")
|
||||
pillow: str = Field(description="Pillow (PIL) version")
|
||||
python: str = Field(description="Python version")
|
||||
torch: str = Field(description="PyTorch version")
|
||||
torchvision: str = Field(description="PyTorch Vision version")
|
||||
transformers: str = Field(description="transformers version")
|
||||
xformers: Optional[str] = Field(description="xformers version")
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""App Config Response"""
|
||||
|
||||
@@ -76,27 +57,19 @@ async def get_version() -> AppVersion:
|
||||
return AppVersion(version=__version__)
|
||||
|
||||
|
||||
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
|
||||
async def get_app_deps() -> AppDependencyVersions:
|
||||
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=dict[str, str])
|
||||
async def get_app_deps() -> dict[str, str]:
|
||||
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
|
||||
try:
|
||||
xformers = version("xformers")
|
||||
except PackageNotFoundError:
|
||||
xformers = None
|
||||
return AppDependencyVersions(
|
||||
accelerate=version("accelerate"),
|
||||
compel=version("compel"),
|
||||
cuda=torch.version.cuda,
|
||||
diffusers=version("diffusers"),
|
||||
numpy=version("numpy"),
|
||||
opencv=version("opencv-python"),
|
||||
onnx=version("onnx"),
|
||||
pillow=version("pillow"),
|
||||
python=python_version(),
|
||||
torch=torch.version.__version__,
|
||||
torchvision=version("torchvision"),
|
||||
transformers=version("transformers"),
|
||||
xformers=xformers,
|
||||
)
|
||||
cuda = torch.version.cuda or "N/A"
|
||||
except Exception:
|
||||
cuda = "N/A"
|
||||
|
||||
deps["CUDA"] = cuda
|
||||
|
||||
sorted_deps = dict(sorted(deps.items(), key=lambda item: item[0].lower()))
|
||||
|
||||
return sorted_deps
|
||||
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
class AddImagesToBoardResult(BaseModel):
|
||||
board_id: str = Field(description="The id of the board the images were added to")
|
||||
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(BaseModel):
|
||||
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="add_image_to_board",
|
||||
@@ -23,17 +14,26 @@ class RemoveImagesFromBoardResult(BaseModel):
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_image_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||
|
||||
@@ -45,14 +45,25 @@ async def add_image_to_board(
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_image_from_board(
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
):
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
return result
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
@@ -72,16 +83,25 @@ async def add_images_to_board(
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
try:
|
||||
added_image_names: list[str] = []
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
board_id=board_id,
|
||||
image_name=image_name,
|
||||
)
|
||||
added_image_names.append(image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
@@ -100,13 +120,20 @@ async def remove_images_from_board(
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
try:
|
||||
removed_image_names: list[str] = []
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_image_names.append(image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
except Exception:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||
|
||||
@@ -146,7 +146,7 @@ async def list_boards(
|
||||
response_model=list[str],
|
||||
)
|
||||
async def list_all_board_image_names(
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
|
||||
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
|
||||
) -> list[str]:
|
||||
|
||||
@@ -1,24 +1,34 @@
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.images.images_common import (
|
||||
DeleteImagesResult,
|
||||
ImageDTO,
|
||||
ImageUrlsDTO,
|
||||
StarredImagesResult,
|
||||
UnstarredImagesResult,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
@@ -27,6 +37,19 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ResizeToDimensions(BaseModel):
|
||||
width: int = Field(..., gt=0)
|
||||
height: int = Field(..., gt=0)
|
||||
|
||||
MAX_SIZE: ClassVar[int] = 4096 * 4096
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_total_output_size(self):
|
||||
if self.width * self.height > self.MAX_SIZE:
|
||||
raise ValueError(f"Max total output size for resizing is {self.MAX_SIZE} pixels")
|
||||
return self
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/upload",
|
||||
operation_id="upload_image",
|
||||
@@ -46,6 +69,11 @@ async def upload_image(
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
resize_to: Optional[str] = Body(
|
||||
default=None,
|
||||
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
|
||||
examples=['"[1024,1024]"'],
|
||||
),
|
||||
metadata: Optional[str] = Body(
|
||||
default=None,
|
||||
description="The metadata to associate with the image, must be a stringified JSON dict",
|
||||
@@ -59,13 +87,33 @@ async def upload_image(
|
||||
contents = await file.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
if crop_visible:
|
||||
bbox = pil_image.getbbox()
|
||||
pil_image = pil_image.crop(bbox)
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
if crop_visible:
|
||||
try:
|
||||
bbox = pil_image.getbbox()
|
||||
pil_image = pil_image.crop(bbox)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to crop image")
|
||||
|
||||
if resize_to:
|
||||
try:
|
||||
dims = json.loads(resize_to)
|
||||
resize_dims = ResizeToDimensions(**dims)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
|
||||
|
||||
try:
|
||||
# heuristic_resize_fast expects an RGB or RGBA image
|
||||
pil_rgba = pil_image.convert("RGBA")
|
||||
np_image = pil_to_np(pil_rgba)
|
||||
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
|
||||
pil_image = np_to_pil(np_image)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to resize image")
|
||||
|
||||
extracted_metadata = extract_metadata_from_image(
|
||||
pil_image=pil_image,
|
||||
invokeai_metadata_override=metadata,
|
||||
@@ -112,18 +160,30 @@ async def create_image_upload_entry(
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
) -> DeleteImagesResult:
|
||||
"""Deletes an image"""
|
||||
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except Exception:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
|
||||
@images_router.delete("/intermediates", operation_id="clear_intermediates")
|
||||
async def clear_intermediates() -> int:
|
||||
@@ -335,23 +395,52 @@ async def list_image_dtos(
|
||||
return image_dtos
|
||||
|
||||
|
||||
class DeleteImagesFromListResult(BaseModel):
|
||||
deleted_images: list[str]
|
||||
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
|
||||
async def delete_images_from_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesFromListResult:
|
||||
) -> DeleteImagesResult:
|
||||
try:
|
||||
deleted_images: list[str] = []
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
|
||||
async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
"""Deletes all images that are uncategorized"""
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id="none", categories=None, is_intermediate=None
|
||||
)
|
||||
|
||||
try:
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.append(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
@@ -360,36 +449,50 @@ class ImagesUpdatedFromListResult(BaseModel):
|
||||
updated_image_names: list[str] = Field(description="The image names that were updated")
|
||||
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
|
||||
async def star_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
) -> StarredImagesResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
starred_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||
updated_image_names.append(image_name)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=True)
|
||||
)
|
||||
starred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except Exception:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
return StarredImagesResult(
|
||||
starred_images=list(starred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
|
||||
async def unstar_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
) -> UnstarredImagesResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
unstarred_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||
updated_image_names.append(image_name)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=False)
|
||||
)
|
||||
unstarred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except Exception:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
return UnstarredImagesResult(
|
||||
unstarred_images=list(unstarred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
@@ -460,3 +563,61 @@ async def get_bulk_download_item(
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image names")
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/images_by_names",
|
||||
operation_id="get_images_by_names",
|
||||
responses={200: {"model": list[ImageDTO]}},
|
||||
)
|
||||
async def get_images_by_names(
|
||||
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
|
||||
) -> list[ImageDTO]:
|
||||
"""Gets image DTOs for the specified image names. Maintains order of input names."""
|
||||
|
||||
try:
|
||||
image_service = ApiDependencies.invoker.services.images
|
||||
|
||||
# Fetch DTOs preserving the order of requested names
|
||||
image_dtos: list[ImageDTO] = []
|
||||
for name in image_names:
|
||||
try:
|
||||
dto = image_service.get_dto(name)
|
||||
image_dtos.append(dto)
|
||||
except Exception:
|
||||
# Skip missing images - they may have been deleted between name fetch and DTO fetch
|
||||
continue
|
||||
|
||||
return image_dtos
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image DTOs")
|
||||
|
||||
@@ -41,6 +41,7 @@ from invokeai.backend.model_manager.starter_models import (
|
||||
STARTER_BUNDLES,
|
||||
STARTER_MODELS,
|
||||
StarterModel,
|
||||
StarterModelBundle,
|
||||
StarterModelWithoutDependencies,
|
||||
)
|
||||
|
||||
@@ -291,7 +292,7 @@ async def get_hugging_face_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -449,7 +450,7 @@ async def install_model(
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
examples=[{"name": "string", "description": "string"}],
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
@@ -799,7 +800,7 @@ async def convert_model(
|
||||
|
||||
class StarterModelResponse(BaseModel):
|
||||
starter_models: list[StarterModel]
|
||||
starter_bundles: dict[str, list[StarterModel]]
|
||||
starter_bundles: dict[str, StarterModelBundle]
|
||||
|
||||
|
||||
def get_is_installed(
|
||||
@@ -833,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse:
|
||||
model.dependencies = missing_deps
|
||||
|
||||
for bundle in starter_bundles.values():
|
||||
for model in bundle:
|
||||
for model in bundle.models:
|
||||
model.is_installed = get_is_installed(model, installed_models)
|
||||
# Remove already-installed dependencies
|
||||
missing_deps: list[StarterModelWithoutDependencies] = []
|
||||
@@ -893,6 +894,12 @@ class HFTokenHelper:
|
||||
huggingface_hub.login(token=token, add_to_git_credential=False)
|
||||
return cls.get_status()
|
||||
|
||||
@classmethod
|
||||
def reset_token(cls) -> HFTokenStatus:
|
||||
with SuppressOutput(), contextlib.suppress(Exception):
|
||||
huggingface_hub.logout()
|
||||
return cls.get_status()
|
||||
|
||||
|
||||
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
|
||||
async def get_hf_login_status() -> HFTokenStatus:
|
||||
@@ -915,3 +922,8 @@ async def do_hf_login(
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
||||
|
||||
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
|
||||
async def reset_hf_token() -> HFTokenStatus:
|
||||
return HFTokenHelper.reset_token()
|
||||
|
||||
215
invokeai/app/api/routers/model_relationships.py
Normal file
215
invokeai/app/api/routers/model_relationships.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""FastAPI route for model relationship records."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
model_relationships_router = APIRouter(prefix="/v1/model_relationships", tags=["model_relationships"])
|
||||
|
||||
# === Schemas ===
|
||||
|
||||
|
||||
class ModelRelationshipCreateRequest(BaseModel):
|
||||
model_key_1: str = Field(
|
||||
...,
|
||||
description="The key of the first model in the relationship",
|
||||
examples=[
|
||||
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
|
||||
"ac32b914-10ab-496e-a24a-3068724b9c35",
|
||||
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
|
||||
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
|
||||
"12345678-90ab-cdef-1234-567890abcdef",
|
||||
"fedcba98-7654-3210-fedc-ba9876543210",
|
||||
],
|
||||
)
|
||||
model_key_2: str = Field(
|
||||
...,
|
||||
description="The key of the second model in the relationship",
|
||||
examples=[
|
||||
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
|
||||
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
|
||||
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
|
||||
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
|
||||
"10f7807b-ab54-46a9-ab03-600e88c630a1",
|
||||
"f6c1d267-cf87-4ee0-bee0-37e791eacab7",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ModelRelationshipBatchRequest(BaseModel):
|
||||
model_keys: List[str] = Field(
|
||||
...,
|
||||
description="List of model keys to fetch related models for",
|
||||
examples=[
|
||||
[
|
||||
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
|
||||
"ac32b914-10ab-496e-a24a-3068724b9c35",
|
||||
],
|
||||
[
|
||||
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
|
||||
"12345678-90ab-cdef-1234-567890abcdef",
|
||||
"fedcba98-7654-3210-fedc-ba9876543210",
|
||||
],
|
||||
[
|
||||
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# === Routes ===
|
||||
|
||||
|
||||
@model_relationships_router.get(
|
||||
"/i/{model_key}",
|
||||
operation_id="get_related_models",
|
||||
response_model=list[str],
|
||||
responses={
|
||||
200: {
|
||||
"description": "A list of related model keys was retrieved successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": [
|
||||
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
|
||||
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
|
||||
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
404: {"description": "The specified model could not be found"},
|
||||
422: {"description": "Validation error"},
|
||||
},
|
||||
)
|
||||
async def get_related_models(
|
||||
model_key: str = Path(..., description="The key of the model to get relationships for"),
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get a list of model keys related to a given model.
|
||||
"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@model_relationships_router.post(
|
||||
"/",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "The relationship was successfully created"},
|
||||
400: {"description": "Invalid model keys or self-referential relationship"},
|
||||
409: {"description": "The relationship already exists"},
|
||||
422: {"description": "Validation error"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
summary="Add Model Relationship",
|
||||
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
|
||||
)
|
||||
async def add_model_relationship(
|
||||
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
|
||||
) -> None:
|
||||
"""
|
||||
Add a relationship between two models.
|
||||
|
||||
Relationships are bidirectional and will be accessible from both models.
|
||||
|
||||
- Raises 400 if keys are invalid or identical.
|
||||
- Raises 409 if the relationship already exists.
|
||||
"""
|
||||
try:
|
||||
if req.model_key_1 == req.model_key_2:
|
||||
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
|
||||
|
||||
ApiDependencies.invoker.services.model_relationships.add_model_relationship(
|
||||
req.model_key_1,
|
||||
req.model_key_2,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@model_relationships_router.delete(
|
||||
"/",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
204: {"description": "The relationship was successfully removed"},
|
||||
400: {"description": "Invalid model keys or self-referential relationship"},
|
||||
404: {"description": "The relationship does not exist"},
|
||||
422: {"description": "Validation error"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
summary="Remove Model Relationship",
|
||||
description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
|
||||
)
|
||||
async def remove_model_relationship(
|
||||
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
|
||||
) -> None:
|
||||
"""
|
||||
Removes a bidirectional relationship between two model keys.
|
||||
|
||||
- Raises 400 if attempting to unlink a model from itself.
|
||||
- Raises 404 if the relationship was not found.
|
||||
"""
|
||||
try:
|
||||
if req.model_key_1 == req.model_key_2:
|
||||
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")
|
||||
|
||||
ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
|
||||
req.model_key_1,
|
||||
req.model_key_2,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@model_relationships_router.post(
|
||||
"/batch",
|
||||
operation_id="get_related_models_batch",
|
||||
response_model=List[str],
|
||||
responses={
|
||||
200: {
|
||||
"description": "Related model keys retrieved successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": [
|
||||
"ca562b14-995e-4a42-90c1-9528f1a5921d",
|
||||
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
|
||||
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
|
||||
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
|
||||
"c382eaa3-0e28-4ab0-9446-408667699aeb",
|
||||
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
|
||||
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
|
||||
]
|
||||
}
|
||||
},
|
||||
},
|
||||
422: {"description": "Validation error"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
summary="Get Related Model Keys (Batch)",
|
||||
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
|
||||
)
|
||||
async def get_related_models_batch(
|
||||
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
|
||||
) -> list[str]:
|
||||
"""
|
||||
Accepts multiple model keys and returns a flat list of all unique related keys.
|
||||
|
||||
Useful when working with multiple selections in the UI or cross-model comparisons.
|
||||
"""
|
||||
try:
|
||||
all_related: set[str] = set()
|
||||
for key in req.model_keys:
|
||||
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
|
||||
all_related.update(related)
|
||||
return list(all_related)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -14,13 +14,15 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
@@ -58,17 +60,19 @@ async def enqueue_batch(
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
try:
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list",
|
||||
operation_id="list_queue_items",
|
||||
responses={
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def list_queue_items(
|
||||
@@ -77,12 +81,42 @@ async def list_queue_items(
|
||||
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
|
||||
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
|
||||
priority: int = Query(default=0, description="The pagination cursor priority"),
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets all queue items (without graphs)"""
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
"""Gets cursor-paginated queue items"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
|
||||
)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id,
|
||||
limit=limit,
|
||||
status=status,
|
||||
cursor=cursor,
|
||||
priority=priority,
|
||||
destination=destination,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list_all",
|
||||
operation_id="list_all_queue_items",
|
||||
responses={
|
||||
200: {"model": list[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def list_all_queue_items(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
queue_id=queue_id,
|
||||
destination=destination,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -94,7 +128,10 @@ async def resume(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -106,7 +143,10 @@ async def Pause(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -118,7 +158,25 @@ async def cancel_all_except_current(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> CancelAllExceptCurrentResult:
|
||||
"""Immediately cancels all queue items except in-processing items"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/delete_all_except_current",
|
||||
operation_id="delete_all_except_current",
|
||||
responses={200: {"model": DeleteAllExceptCurrentResult}},
|
||||
)
|
||||
async def delete_all_except_current(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> DeleteAllExceptCurrentResult:
|
||||
"""Immediately deletes all queue items except in-processing items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -131,7 +189,12 @@ async def cancel_by_batch_ids(
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
|
||||
queue_id=queue_id, batch_ids=batch_ids
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -144,9 +207,12 @@ async def cancel_by_destination(
|
||||
destination: str = Query(description="The destination to cancel all queue items for"),
|
||||
) -> CancelByDestinationResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -159,7 +225,10 @@ async def retry_items_by_id(
|
||||
item_ids: list[int] = Body(description="The queue item ids to retry"),
|
||||
) -> RetryItemsResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -173,11 +242,14 @@ async def clear(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -191,7 +263,10 @@ async def prune(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -205,7 +280,10 @@ async def get_current_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -219,7 +297,10 @@ async def get_next_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -233,9 +314,12 @@ async def get_queue_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -250,7 +334,10 @@ async def get_batch_status(
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -266,7 +353,27 @@ async def get_queue_item(
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
"""Gets a queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.delete(
|
||||
"/{queue_id}/i/{item_id}",
|
||||
operation_id="delete_queue_item",
|
||||
)
|
||||
async def delete_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to delete"),
|
||||
) -> None:
|
||||
"""Deletes a queue item"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -281,8 +388,12 @@ async def cancel_queue_item(
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -295,6 +406,27 @@ async def counts_by_destination(
|
||||
destination: str = Query(description="The destination to query"),
|
||||
) -> SessionQueueCountsByDestination:
|
||||
"""Gets the counts of queue items by destination"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
|
||||
|
||||
|
||||
@session_queue_router.delete(
|
||||
"/{queue_id}/d/{destination}",
|
||||
operation_id="delete_by_destination",
|
||||
responses={200: {"model": DeleteByDestinationResult}},
|
||||
)
|
||||
async def delete_by_destination(
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Path(description="The destination to query"),
|
||||
) -> DeleteByDestinationResult:
|
||||
"""Deletes all items with the given destination"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
|
||||
|
||||
@@ -22,6 +22,7 @@ from invokeai.app.api.routers import (
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
model_relationships,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
@@ -125,6 +126,7 @@ app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(model_relationships.model_relationships_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
@@ -156,7 +158,7 @@ web_root_path = Path(list(web_dir.__path__)[0])
|
||||
try:
|
||||
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
|
||||
except RuntimeError:
|
||||
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
|
||||
app.mount(
|
||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||
) # docs favicon is in here
|
||||
|
||||
@@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -20,12 +22,14 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
@@ -72,13 +76,24 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Special = "special"
|
||||
|
||||
|
||||
class Bottleneck(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The bottleneck of an invocation.
|
||||
- `Network`: The invocation's execution is network-bound.
|
||||
- `GPU`: The invocation's execution is GPU-bound.
|
||||
"""
|
||||
|
||||
Network = "network"
|
||||
GPU = "gpu"
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
This is used internally by the @invocation decorator logic. Do not use this directly.
|
||||
"""
|
||||
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||
tags: Optional[list[str]] = Field(default=None, description="The node's tags")
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: str = Field(
|
||||
@@ -93,6 +108,11 @@ class UIConfigBase(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class OriginalModelField(TypedDict):
|
||||
annotation: Any
|
||||
field_info: FieldInfo
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
Base class for all invocation outputs.
|
||||
@@ -100,6 +120,12 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
output_meta: Optional[dict[str, JsonValue]] = Field(
|
||||
default=None,
|
||||
description="Optional dictionary of metadata for the invocation output, unrelated to the invocation's actual output value. This is not exposed as an output field.",
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
@@ -115,6 +141,9 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
|
||||
"""The original model fields, before any modifications were made by the @invocation_output decorator."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
@@ -148,7 +177,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
@classmethod
|
||||
def get_output_annotation(cls) -> BaseInvocationOutput:
|
||||
def get_output_annotation(cls) -> Type[BaseInvocationOutput]:
|
||||
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@@ -180,7 +209,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
Internal invoke method, calls `invoke()` after some prep.
|
||||
Handles optional fields that are required to call `invoke()` and invocation cache.
|
||||
"""
|
||||
for field_name, field in self.model_fields.items():
|
||||
for field_name, field in type(self).model_fields.items():
|
||||
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||
# something has gone terribly awry, we should always have this and it should be a dict
|
||||
continue
|
||||
@@ -195,9 +224,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
setattr(self, field_name, orig_default)
|
||||
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
|
||||
if input_ == Input.Connection:
|
||||
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
|
||||
raise RequiredConnectionException(type(self).model_fields["type"].default, field_name)
|
||||
elif input_ == Input.Any:
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
raise MissingInputException(type(self).model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if services.configuration.node_cache_size == 0:
|
||||
@@ -235,6 +264,8 @@ class BaseInvocation(ABC, BaseModel):
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
bottleneck: ClassVar[Bottleneck]
|
||||
|
||||
UIConfig: ClassVar[UIConfigBase]
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -245,6 +276,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
|
||||
"""The original model fields, before any modifications were made by the @invocation decorator."""
|
||||
|
||||
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
@@ -256,6 +290,26 @@ class InvocationRegistry:
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
|
||||
"""Registers an invocation."""
|
||||
|
||||
invocation_type = invocation.get_type()
|
||||
node_pack = invocation.UIConfig.node_pack
|
||||
|
||||
# Log a warning when an existing invocation is being clobbered by the one we are registering
|
||||
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
|
||||
if clobbered_invocation is not None:
|
||||
# This should always be true - we just checked if the invocation type was in the set
|
||||
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
|
||||
|
||||
if clobbered_node_pack == "invokeai":
|
||||
# The invocation being clobbered is a core invocation
|
||||
logger.warning(f'Overriding core node "{invocation_type}" with node from "{node_pack}"')
|
||||
else:
|
||||
# The invocation being clobbered is a custom invocation
|
||||
logger.warning(
|
||||
f'Overriding node "{invocation_type}" from "{node_pack}" with node from "{clobbered_node_pack}"'
|
||||
)
|
||||
cls._invocation_classes.remove(clobbered_invocation)
|
||||
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls.invalidate_invocation_typeadapter()
|
||||
|
||||
@@ -314,6 +368,15 @@ class InvocationRegistry:
|
||||
@classmethod
|
||||
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
|
||||
"""Registers an invocation output."""
|
||||
output_type = output.get_type()
|
||||
|
||||
# Log a warning when an existing invocation is being clobbered by the one we are registering
|
||||
clobbered_output = InvocationRegistry.get_output_for_type(output_type)
|
||||
if clobbered_output is not None:
|
||||
# TODO(psyche): We do not record the node pack of the output, so we cannot log it here
|
||||
logger.warning(f'Overriding invocation output "{output_type}"')
|
||||
cls._output_classes.remove(clobbered_output)
|
||||
|
||||
cls._output_classes.add(output)
|
||||
cls.invalidate_output_typeadapter()
|
||||
|
||||
@@ -322,6 +385,11 @@ class InvocationRegistry:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_outputs_map(cls) -> dict[str, type[BaseInvocationOutput]]:
|
||||
"""Gets a map of all output types to their output classes."""
|
||||
return {i.get_type(): i for i in cls.get_output_classes()}
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
@@ -347,6 +415,11 @@ class InvocationRegistry:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in cls.get_output_classes())
|
||||
|
||||
@classmethod
|
||||
def get_output_for_type(cls, output_type: str) -> type[BaseInvocationOutput] | None:
|
||||
"""Gets the output class for a given output type."""
|
||||
return cls.get_outputs_map().get(output_type)
|
||||
|
||||
|
||||
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"id",
|
||||
@@ -354,11 +427,12 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"use_cache",
|
||||
"type",
|
||||
"workflow",
|
||||
"bottleneck",
|
||||
}
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
|
||||
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type", "output_meta"}
|
||||
|
||||
|
||||
class _Model(BaseModel):
|
||||
@@ -425,11 +499,53 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
|
||||
ui_type = field.json_schema_extra.get("ui_type", None)
|
||||
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
|
||||
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
|
||||
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
return None
|
||||
|
||||
|
||||
class NoDefaultSentinel:
|
||||
pass
|
||||
|
||||
|
||||
def validate_field_default(
|
||||
cls_name: str, field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo
|
||||
) -> None:
|
||||
"""Validates the default value of a field against its pydantic field definition."""
|
||||
|
||||
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
|
||||
|
||||
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
|
||||
# We store the original default value in the json_schema_extra dict, so we can validate it here.
|
||||
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
|
||||
|
||||
if orig_default is NoDefaultSentinel:
|
||||
return
|
||||
|
||||
# To validate the default value, we can create a temporary pydantic model with the field we are validating as its
|
||||
# only field. Then validate the default value against this temporary model.
|
||||
TempDefaultValidator = cast(BaseModel, create_model(cls_name, **{field_name: (annotation, field_info)}))
|
||||
|
||||
try:
|
||||
TempDefaultValidator.model_validate({field_name: orig_default})
|
||||
except Exception as e:
|
||||
raise InvalidFieldError(
|
||||
f'Default value for field "{field_name}" on invocation "{invocation_type}" is invalid, {e}'
|
||||
) from e
|
||||
|
||||
|
||||
def is_optional(annotation: Any) -> bool:
|
||||
"""
|
||||
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
|
||||
"""
|
||||
origin = typing.get_origin(annotation)
|
||||
# PEP 604 unions (int|None) have origin types.UnionType
|
||||
is_union = origin is typing.Union or origin is types.UnionType
|
||||
if not is_union:
|
||||
return False
|
||||
return any(arg is type(None) for arg in typing.get_args(annotation))
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
@@ -438,6 +554,7 @@ def invocation(
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
classification: Classification = Classification.Stable,
|
||||
bottleneck: Bottleneck = Bottleneck.GPU,
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Registers an invocation.
|
||||
@@ -449,6 +566,7 @@ def invocation(
|
||||
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
|
||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
|
||||
:param Bottleneck bottleneck: The bottleneck of the invocation. Defaults to Bottleneck.GPU. Use Network if the invocation is network-bound.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||
@@ -460,27 +578,28 @@ def invocation(
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
node_pack = cls.__module__.split(".")[0]
|
||||
|
||||
# Handle the case where an existing node is being clobbered by the one we are registering
|
||||
if invocation_type in InvocationRegistry.get_invocation_types():
|
||||
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
|
||||
# This should always be true - we just checked if the invocation type was in the set
|
||||
assert clobbered_invocation is not None
|
||||
|
||||
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
|
||||
|
||||
if clobbered_node_pack == "invokeai":
|
||||
# The node being clobbered is a core node
|
||||
raise ValueError(
|
||||
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
|
||||
)
|
||||
else:
|
||||
# The node being clobbered is a custom node
|
||||
raise ValueError(
|
||||
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
|
||||
)
|
||||
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
original_model_fields: dict[str, OriginalModelField] = {}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
|
||||
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
|
||||
|
||||
if field_info.default is None and not is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconfig: dict[str, Any] = {}
|
||||
uiconfig["title"] = title
|
||||
@@ -496,7 +615,7 @@ def invocation(
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
uiconfig["version"] = version
|
||||
else:
|
||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
uiconfig["version"] = "1.0.0"
|
||||
|
||||
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||
@@ -504,6 +623,8 @@ def invocation(
|
||||
if use_cache is not None:
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
cls.bottleneck = bottleneck
|
||||
|
||||
# Add the invocation type to the model.
|
||||
|
||||
# You'd be tempted to just add the type field and rebuild the model, like this:
|
||||
@@ -513,11 +634,27 @@ def invocation(
|
||||
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
|
||||
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
|
||||
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
invocation_type_field = Field(
|
||||
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
invocation_type_annotation = Literal[invocation_type]
|
||||
|
||||
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
|
||||
# This cast makes the type annotation match the class's true type.
|
||||
invocation_type_field_info = cast(
|
||||
FieldInfo,
|
||||
Field(title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
|
||||
)
|
||||
|
||||
fields["type"] = (invocation_type_annotation, invocation_type_field_info)
|
||||
|
||||
# Invocation outputs must be registered using the @invocation_output decorator, but it is possible that the
|
||||
# output is registered _after_ this invocation is registered. It depends on module import ordering.
|
||||
#
|
||||
# We can only confirm the output for an invocation is registered after all modules are imported. There's
|
||||
# only really one good time to do that - during application startup, in `run_app.py`, after loading all
|
||||
# custom nodes.
|
||||
#
|
||||
# We can still do some basic validation here - ensure the invoke method is defined and returns an instance
|
||||
# of BaseInvocationOutput.
|
||||
|
||||
# Validate the `invoke()` method is implemented
|
||||
if "invoke" in cls.__abstractmethods__:
|
||||
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
|
||||
@@ -539,17 +676,13 @@ def invocation(
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(invocation_type_annotation, invocation_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
|
||||
new_class.__doc__ = docstring
|
||||
new_class._original_model_fields = original_model_fields
|
||||
|
||||
InvocationRegistry.register_invocation(cls)
|
||||
InvocationRegistry.register_invocation(new_class)
|
||||
|
||||
return cls
|
||||
return new_class
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -572,29 +705,41 @@ def invocation_output(
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
if output_type in InvocationRegistry.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
|
||||
# Add the output type to the model.
|
||||
fields: dict[str, tuple[Any, FieldInfo]] = {}
|
||||
|
||||
output_type_annotation = Literal[output_type] # type: ignore
|
||||
output_type_field = Field(
|
||||
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
|
||||
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
if field_info.default is not PydanticUndefined and is_optional(annotation):
|
||||
annotation = annotation | None
|
||||
fields[field_name] = (annotation, field_info)
|
||||
|
||||
# Add the output type to the model.
|
||||
output_type_annotation = Literal[output_type]
|
||||
|
||||
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
|
||||
# This cast makes the type annotation match the class's true type.
|
||||
output_type_field_info = cast(
|
||||
FieldInfo,
|
||||
Field(title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
|
||||
)
|
||||
|
||||
fields["type"] = (output_type_annotation, output_type_field_info)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
__base__=cls,
|
||||
__module__=cls.__module__,
|
||||
type=(output_type_annotation, output_type_field),
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
|
||||
new_class.__doc__ = docstring
|
||||
|
||||
InvocationRegistry.register_output(cls)
|
||||
InvocationRegistry.register_output(new_class)
|
||||
|
||||
return cls
|
||||
return new_class
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -64,7 +64,6 @@ class ImageBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
default=[],
|
||||
min_length=1,
|
||||
description="The images to batch over",
|
||||
)
|
||||
@@ -120,7 +119,6 @@ class StringBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
|
||||
|
||||
strings: list[str] = InputField(
|
||||
default=[],
|
||||
min_length=1,
|
||||
description="The strings to batch over",
|
||||
)
|
||||
@@ -176,7 +174,6 @@ class IntegerBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
|
||||
|
||||
integers: list[int] = InputField(
|
||||
default=[],
|
||||
min_length=1,
|
||||
description="The integers to batch over",
|
||||
)
|
||||
@@ -230,7 +227,6 @@ class FloatBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
|
||||
|
||||
floats: list[float] = InputField(
|
||||
default=[],
|
||||
min_length=1,
|
||||
description="The floats to batch over",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel import Compel, ReturnedEmbeddingsType, SplitLongTextMode
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
@@ -104,6 +104,7 @@ class CompelInvocation(BaseInvocation):
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
@@ -113,6 +114,13 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
del compel
|
||||
del patched_tokenizer
|
||||
del tokenizer
|
||||
del ti_manager
|
||||
del text_encoder
|
||||
del text_encoder_info
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
|
||||
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
|
||||
@@ -205,6 +213,7 @@ class SDXLPromptInvocationBase:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
@@ -220,7 +229,10 @@ class SDXLPromptInvocationBase:
|
||||
else:
|
||||
c_pooled = None
|
||||
|
||||
del compel
|
||||
del patched_tokenizer
|
||||
del tokenizer
|
||||
del ti_manager
|
||||
del text_encoder
|
||||
del text_encoder_info
|
||||
|
||||
|
||||
@@ -274,12 +274,12 @@ class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard
|
||||
title="Enhance Image",
|
||||
tags=["enhance", "image"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class InvokeImageEnhanceInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Applies processing from PIL's ImageEnhance module. Originally created by @dwringer"""
|
||||
|
||||
image: ImageField = InputField(default=None, description="The image for which to apply processing")
|
||||
image: ImageField = InputField(description="The image for which to apply processing")
|
||||
invert: bool = InputField(default=False, description="Whether to invert the image colors")
|
||||
color: float = InputField(ge=0, default=1.0, description="Color enhancement factor")
|
||||
contrast: float = InputField(ge=0, default=1.0, description="Contrast enhancement factor")
|
||||
|
||||
@@ -22,7 +22,11 @@ from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.app.util.controlnet_utils import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
heuristic_resize_fast,
|
||||
)
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
|
||||
|
||||
@@ -109,7 +113,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class HeuristicResizeInvocation(BaseInvocation):
|
||||
@@ -122,7 +126,7 @@ class HeuristicResizeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
np_img = pil_to_np(image)
|
||||
np_resized = heuristic_resize(np_img, (self.width, self.height))
|
||||
np_resized = heuristic_resize_fast(np_img, (self.width, self.height))
|
||||
resized = np_to_pil(np_resized)
|
||||
image_dto = context.images.save(image=resized)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
@@ -42,15 +44,13 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.2.0",
|
||||
version="1.3.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
"""Creates mask for denoising."""
|
||||
|
||||
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
edge_radius: int = InputField(
|
||||
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
|
||||
)
|
||||
mask: ImageField = InputField(description="Image which will be masked", ui_order=1)
|
||||
edge_radius: int = InputField(default=16, ge=0, description="How far to expand the edges of the mask", ui_order=2)
|
||||
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
|
||||
minimum_denoise: float = InputField(
|
||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||
@@ -81,45 +81,110 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
# Resize the mask_image. Makes the filter 64x faster and doesn't hurt quality in latent scale anyway
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // LATENT_SCALE_FACTOR,
|
||||
mask_image.height // LATENT_SCALE_FACTOR,
|
||||
),
|
||||
resample=Image.Resampling.BILINEAR,
|
||||
)
|
||||
|
||||
mask_np_orig = np.array(mask_image, dtype=np.float32)
|
||||
|
||||
self.edge_radius = self.edge_radius // LATENT_SCALE_FACTOR # scale the edge radius to match the mask size
|
||||
|
||||
if self.edge_radius > 0:
|
||||
mask_np = 255 - mask_np_orig # invert so 0 is unmasked (higher values = higher denoise strength)
|
||||
dilated_mask = mask_np.copy()
|
||||
|
||||
# Create kernel based on coherence mode
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
# Create a circular distance kernel that fades from center outward
|
||||
kernel_size = self.edge_radius * 2 + 1
|
||||
center = self.edge_radius
|
||||
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
|
||||
for i in range(kernel_size):
|
||||
for j in range(kernel_size):
|
||||
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
|
||||
if dist <= self.edge_radius:
|
||||
kernel[i, j] = 1.0 - (dist / self.edge_radius)
|
||||
else: # Gaussian Blur or Staged
|
||||
# Create a Gaussian kernel
|
||||
kernel_size = self.edge_radius * 2 + 1
|
||||
kernel = cv2.getGaussianKernel(
|
||||
kernel_size, self.edge_radius / 2.5
|
||||
) # 2.5 is a magic number (standard deviation capturing)
|
||||
kernel = kernel * kernel.T # Make 2D gaussian kernel
|
||||
kernel = kernel / np.max(kernel) # Normalize center to 1.0
|
||||
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
# Ensure values outside radius are 0
|
||||
center = self.edge_radius
|
||||
for i in range(kernel_size):
|
||||
for j in range(kernel_size):
|
||||
dist = np.sqrt((i - center) ** 2 + (j - center) ** 2)
|
||||
if dist > self.edge_radius:
|
||||
kernel[i, j] = 0
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
blur_tensor[blur_tensor < 0] = 0.0
|
||||
# 2D max filter
|
||||
mask_tensor = torch.tensor(mask_np)
|
||||
kernel_tensor = torch.tensor(kernel)
|
||||
dilated_mask = 255 - self.max_filter2D_torch(mask_tensor, kernel_tensor).cpu()
|
||||
dilated_mask = dilated_mask.numpy()
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
threshold = (1 - self.minimum_denoise) * 255
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
# wherever expanded mask is darker than the original mask but original was above threshhold, set it to the threshold
|
||||
# makes any expansion areas drop to threshhold. Raising minimum across the image happen outside of this if
|
||||
threshold_mask = (dilated_mask < mask_np_orig) & (mask_np_orig > threshold)
|
||||
dilated_mask = np.where(threshold_mask, threshold, mask_np_orig)
|
||||
|
||||
# wherever expanded mask is less than 255 but greater than threshold, drop it to threshold (minimum denoise)
|
||||
threshold_mask = (dilated_mask > threshold) & (dilated_mask < 255)
|
||||
dilated_mask = np.where(threshold_mask, threshold, dilated_mask)
|
||||
|
||||
else:
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
dilated_mask = mask_np_orig.copy()
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
# convert to tensor
|
||||
dilated_mask = np.clip(dilated_mask, 0, 255).astype(np.uint8)
|
||||
mask_tensor = torch.tensor(dilated_mask, device=torch.device("cpu"))
|
||||
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
# binary mask for compositing
|
||||
expanded_mask = np.where((dilated_mask < 255), 0, 255)
|
||||
expanded_mask_image = Image.fromarray(expanded_mask.astype(np.uint8), mode="L")
|
||||
expanded_mask_image = expanded_mask_image.resize(
|
||||
(
|
||||
mask_image.width * LATENT_SCALE_FACTOR,
|
||||
mask_image.height * LATENT_SCALE_FACTOR,
|
||||
),
|
||||
resample=Image.Resampling.NEAREST,
|
||||
)
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
# restore the original mask size
|
||||
dilated_mask = Image.fromarray(dilated_mask.astype(np.uint8))
|
||||
dilated_mask = dilated_mask.resize(
|
||||
(
|
||||
mask_image.width * LATENT_SCALE_FACTOR,
|
||||
mask_image.height * LATENT_SCALE_FACTOR,
|
||||
),
|
||||
resample=Image.Resampling.NEAREST,
|
||||
)
|
||||
|
||||
# stack the mask as a tensor, repeating 4 times on dimmension 1
|
||||
dilated_mask_tensor = image_resized_to_grid_as_tensor(dilated_mask, normalize=False)
|
||||
mask_name = context.tensors.save(tensor=dilated_mask_tensor.unsqueeze(0))
|
||||
|
||||
masked_latents_name = None
|
||||
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||
# all three fields must be present at the same time
|
||||
main_model_config = context.models.get_config(self.unet.unet.key)
|
||||
assert isinstance(main_model_config, MainConfigBase)
|
||||
if main_model_config.variant is ModelVariantType.Inpaint:
|
||||
mask = blur_tensor
|
||||
mask = dilated_mask_tensor
|
||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
@@ -137,3 +202,29 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
def max_filter2D_torch(self, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This morphological operation is much faster in torch than numpy or opencv
|
||||
For reasonable kernel sizes, the overhead of copying the data to the GPU is not worth it.
|
||||
"""
|
||||
h, w = kernel.shape
|
||||
pad_h, pad_w = h // 2, w // 2
|
||||
|
||||
padded = torch.nn.functional.pad(image, (pad_w, pad_w, pad_h, pad_h), mode="constant", value=0)
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
# This looks like it's inside out, but it does the same thing and is more efficient
|
||||
for i in range(h):
|
||||
for j in range(w):
|
||||
weight = kernel[i, j]
|
||||
if weight <= 0:
|
||||
continue
|
||||
|
||||
# Extract the region from padded tensor
|
||||
region = padded[i : i + image.shape[0], j : j + image.shape[1]]
|
||||
|
||||
# Apply weight and update max
|
||||
result = torch.maximum(result, region * weight)
|
||||
|
||||
return result
|
||||
|
||||
@@ -608,6 +608,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
mask=mask,
|
||||
method=single_ip_adapter.method,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -61,6 +61,10 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
LlavaOnevisionModel = "LLaVAModelField"
|
||||
Imagen3Model = "Imagen3ModelField"
|
||||
Imagen4Model = "Imagen4ModelField"
|
||||
ChatGPT4oModel = "ChatGPT4oModelField"
|
||||
FluxKontextModel = "FluxKontextModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -211,6 +215,7 @@ class FieldDescriptions:
|
||||
flux_redux_conditioning = "FLUX Redux conditioning tensor"
|
||||
vllm_model = "The VLLM model to use"
|
||||
flux_fill_conditioning = "FLUX Fill conditioning tensor"
|
||||
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
@@ -287,6 +292,12 @@ class FluxFillConditioningField(BaseModel):
|
||||
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
|
||||
|
||||
|
||||
class FluxKontextConditioningField(BaseModel):
|
||||
"""A conditioning field for FLUX Kontext (reference image)."""
|
||||
|
||||
image: ImageField = Field(description="The Kontext reference image.")
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
@@ -398,8 +409,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
orig_required: bool = True
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
@@ -434,7 +445,7 @@ class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
@@ -496,7 +507,7 @@ def InputField(
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_hidden: Optional[bool] = None,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
) -> Any:
|
||||
@@ -532,15 +543,20 @@ def InputField(
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
if ui_type is not None:
|
||||
json_schema_extra_.ui_type = ui_type
|
||||
if ui_component is not None:
|
||||
json_schema_extra_.ui_component = ui_component
|
||||
if ui_hidden is not None:
|
||||
json_schema_extra_.ui_hidden = ui_hidden
|
||||
if ui_order is not None:
|
||||
json_schema_extra_.ui_order = ui_order
|
||||
if ui_choice_labels is not None:
|
||||
json_schema_extra_.ui_choice_labels = ui_choice_labels
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
@@ -570,7 +586,7 @@ def InputField(
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
logger.warning('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
@@ -612,7 +628,7 @@ def InputField(
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_unset=True),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,12 @@ from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxFillConditioningField,
|
||||
FluxKontextConditioningField,
|
||||
FluxReduxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
@@ -34,6 +33,7 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.3.0",
|
||||
version="4.0.0",
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class FluxDenoiseInvocation(BaseInvocation):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
@@ -145,11 +145,20 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# This node accepts a images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
|
||||
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
|
||||
# upstream nodes.
|
||||
|
||||
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
|
||||
)
|
||||
|
||||
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
|
||||
default=None,
|
||||
description="FLUX Kontext conditioning (reference image).",
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
@@ -376,6 +385,27 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
kontext_extension = None
|
||||
if self.kontext_conditioning is not None:
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
|
||||
|
||||
kontext_extension = KontextExtension(
|
||||
context=context,
|
||||
kontext_conditioning=self.kontext_conditioning,
|
||||
vae_field=self.controlnet_vae,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
|
||||
# Prepare Kontext conditioning if provided
|
||||
img_cond_seq = None
|
||||
img_cond_seq_ids = None
|
||||
if kontext_extension is not None:
|
||||
# Ensure batch sizes match
|
||||
kontext_extension.ensure_batch_size(x.shape[0])
|
||||
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
@@ -391,6 +421,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
img_cond=img_cond,
|
||||
img_cond_seq=img_cond_seq,
|
||||
img_cond_seq_ids=img_cond_seq_ids,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
@@ -865,7 +897,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
# The denoise function now handles Kontext conditioning correctly,
|
||||
# so we don't need to slice the latents here
|
||||
latents = state.latents.float()
|
||||
state.latents = unpack(latents, self.height, self.width).squeeze()
|
||||
context.util.flux_step_callback(state)
|
||||
|
||||
return step_callback
|
||||
|
||||
40
invokeai/app/invocations/flux_kontext.py
Normal file
40
invokeai/app/invocations/flux_kontext.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxKontextConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_kontext_output")
|
||||
class FluxKontextOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Kontext invocation."""
|
||||
|
||||
kontext_cond: FluxKontextConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_kontext",
|
||||
title="Kontext Conditioning - FLUX",
|
||||
tags=["conditioning", "kontext", "flux"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextInvocation(BaseInvocation):
|
||||
"""Prepares a reference image for FLUX Kontext conditioning."""
|
||||
|
||||
image: ImageField = InputField(description="The Kontext reference image.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
|
||||
"""Packages the provided image into a Kontext conditioning field."""
|
||||
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))
|
||||
@@ -3,6 +3,7 @@ from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -115,8 +116,14 @@ class FluxReduxInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
|
||||
siglip_model_config = self._get_siglip_model(context)
|
||||
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
|
||||
assert isinstance(siglip_pipeline, SigLipPipeline)
|
||||
with context.models.load(siglip_model_config.key).model_on_device() as (_, model):
|
||||
assert isinstance(model, SiglipVisionModel)
|
||||
|
||||
model_abs_path = context.models.get_absolute_path(siglip_model_config)
|
||||
processor = SiglipImageProcessor.from_pretrained(model_abs_path, local_files_only=True)
|
||||
assert isinstance(processor, SiglipImageProcessor)
|
||||
|
||||
siglip_pipeline = SigLipPipeline(processor, model)
|
||||
return siglip_pipeline.encode_image(
|
||||
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
from typing import Iterator, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
@@ -111,6 +111,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_t5_tokenization(context, t5_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
@@ -151,6 +154,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
self._log_clip_tokenization(context, clip_tokenizer)
|
||||
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
@@ -170,3 +176,88 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _log_t5_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
|
||||
|
||||
# Tokenize the prompt using the same parameters as the model's text encoder.
|
||||
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=self.t5_max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True, # This is important for T5 to add the EOS token.
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The T5 tokenizer uses a space-like character ' ' (U+2581) to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("\u2581", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for token in tokens:
|
||||
if token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
used_tokens += 1
|
||||
elif token == tokenizer.pad_token:
|
||||
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
|
||||
continue
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
def _log_clip_tokenization(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
tokenizer: CLIPTokenizer,
|
||||
) -> None:
|
||||
"""Logs the tokenization of a prompt for a CLIP-based model."""
|
||||
max_length = tokenizer.model_max_length
|
||||
|
||||
tokenized_output = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = tokenized_output.input_ids[0]
|
||||
attention_mask = tokenized_output.attention_mask[0]
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# The CLIP tokenizer uses '</w>' to denote spaces.
|
||||
# We'll replace it with a regular space for readability.
|
||||
tokens = [t.replace("</w>", " ") for t in tokens]
|
||||
|
||||
tokenized_str = ""
|
||||
used_tokens = 0
|
||||
for i, token in enumerate(tokens):
|
||||
if attention_mask[i] == 0:
|
||||
# Do not log padding tokens.
|
||||
continue
|
||||
|
||||
if token == tokenizer.bos_token:
|
||||
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
|
||||
elif token == tokenizer.eos_token:
|
||||
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
|
||||
else:
|
||||
color = (used_tokens % 6) + 1 # Cycle through 6 colors
|
||||
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
|
||||
used_tokens += 1
|
||||
|
||||
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
|
||||
context.logger.info(f"{tokenized_str}\x1b[0m")
|
||||
|
||||
@@ -21,14 +21,14 @@ class IdealSizeOutput(BaseInvocationOutput):
|
||||
"ideal_size",
|
||||
title="Ideal Size - SD1.5, SDXL",
|
||||
tags=["latents", "math", "ideal_size"],
|
||||
version="1.0.5",
|
||||
version="1.0.6",
|
||||
)
|
||||
class IdealSizeInvocation(BaseInvocation):
|
||||
"""Calculates the ideal size for generation to avoid duplication"""
|
||||
|
||||
width: int = InputField(default=1024, description="Final image width")
|
||||
height: int = InputField(default=576, description="Final image height")
|
||||
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet)
|
||||
multiplier: float = InputField(
|
||||
default=1.0,
|
||||
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
|
||||
|
||||
@@ -975,13 +975,13 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Canvas Paste Back",
|
||||
tags=["image", "combine"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
|
||||
|
||||
source_image: ImageField = InputField(description="The source image")
|
||||
target_image: ImageField = InputField(default=None, description="The target image")
|
||||
target_image: ImageField = InputField(description="The target image")
|
||||
mask: ImageField = InputField(
|
||||
description="The mask to use when pasting",
|
||||
)
|
||||
@@ -1218,12 +1218,15 @@ class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Add Image Noise",
|
||||
tags=["image", "noise"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
version="1.1.0",
|
||||
)
|
||||
class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add noise to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to add noise to")
|
||||
mask: Optional[ImageField] = InputField(
|
||||
default=None, description="Optional mask determining where to apply noise (black=noise, white=no noise)"
|
||||
)
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
@@ -1267,12 +1270,27 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
noise = Image.fromarray(noise.astype(numpy.uint8), mode="RGB").resize(
|
||||
(image.width, image.height), Image.Resampling.NEAREST
|
||||
)
|
||||
|
||||
# Create a noisy version of the input image
|
||||
noisy_image = Image.blend(image.convert("RGB"), noise, self.amount).convert("RGBA")
|
||||
|
||||
# Paste back the alpha channel
|
||||
noisy_image.putalpha(alpha)
|
||||
# Apply mask if provided
|
||||
if self.mask is not None:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
image_dto = context.images.save(image=noisy_image)
|
||||
if mask_image.size != image.size:
|
||||
mask_image = mask_image.resize(image.size, Image.Resampling.LANCZOS)
|
||||
|
||||
result_image = image.copy()
|
||||
mask_image = ImageOps.invert(mask_image)
|
||||
result_image.paste(noisy_image, (0, 0), mask=mask_image)
|
||||
else:
|
||||
result_image = noisy_image
|
||||
|
||||
# Paste back the alpha channel from the original image
|
||||
result_image.putalpha(alpha)
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class IPAdapterField(BaseModel):
|
||||
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
|
||||
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
|
||||
method: str = Field(default="full", description="Weight apply method")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
@@ -94,7 +95,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, description="The weight given to the IP-Adapter", title="Weight"
|
||||
)
|
||||
method: Literal["full", "style", "composition"] = InputField(
|
||||
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = InputField(
|
||||
default="full", description="The method to apply the IP-Adapter"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
@@ -147,6 +148,38 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
target_blocks = ["down_blocks.2.attentions.1"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "style_precise":
|
||||
if ip_adapter_info.base == "sd-1":
|
||||
target_blocks = ["up_blocks.1", "down_blocks.2", "mid_block"]
|
||||
elif ip_adapter_info.base == "sdxl":
|
||||
target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "style_strong":
|
||||
if ip_adapter_info.base == "sd-1":
|
||||
target_blocks = ["up_blocks.0", "up_blocks.1", "up_blocks.2", "down_blocks.0", "down_blocks.1"]
|
||||
elif ip_adapter_info.base == "sdxl":
|
||||
target_blocks = [
|
||||
"up_blocks.0.attentions.1",
|
||||
"up_blocks.1.attentions.1",
|
||||
"up_blocks.2.attentions.1",
|
||||
"up_blocks.0.attentions.2",
|
||||
"up_blocks.1.attentions.2",
|
||||
"up_blocks.2.attentions.2",
|
||||
"up_blocks.0.attentions.0",
|
||||
"up_blocks.1.attentions.0",
|
||||
"up_blocks.2.attentions.0",
|
||||
"down_blocks.0.attentions.0",
|
||||
"down_blocks.0.attentions.1",
|
||||
"down_blocks.0.attentions.2",
|
||||
"down_blocks.1.attentions.0",
|
||||
"down_blocks.1.attentions.1",
|
||||
"down_blocks.1.attentions.2",
|
||||
"down_blocks.2.attentions.0",
|
||||
"down_blocks.2.attentions.2",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
|
||||
elif self.method == "full":
|
||||
target_blocks = ["block"]
|
||||
else:
|
||||
@@ -162,6 +195,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
mask=self.mask,
|
||||
method=self.method,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ from typing import Any
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic import field_validator
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@@ -54,10 +55,17 @@ class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
images = self._get_images(context)
|
||||
model_config = context.models.get_config(self.vllm_model)
|
||||
|
||||
with context.models.load(self.vllm_model) as vllm_model:
|
||||
assert isinstance(vllm_model, LlavaOnevisionModel)
|
||||
output = vllm_model.run(
|
||||
with context.models.load(self.vllm_model).model_on_device() as (_, model):
|
||||
assert isinstance(model, LlavaOnevisionForConditionalGeneration)
|
||||
|
||||
model_abs_path = context.models.get_absolute_path(model_config)
|
||||
processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True)
|
||||
assert isinstance(processor, LlavaOnevisionProcessor)
|
||||
|
||||
model = LlavaOnevisionPipeline(model, processor)
|
||||
output = model.run(
|
||||
prompt=self.prompt,
|
||||
images=images,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
|
||||
@@ -42,7 +42,9 @@ class IPAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||
clip_vision_model: Literal["ViT-L", "ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
|
||||
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
|
||||
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = Field(
|
||||
description="Method to apply IP Weights with"
|
||||
)
|
||||
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
||||
from transformers import AutoProcessor
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
@@ -104,14 +104,13 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
|
||||
@staticmethod
|
||||
def _load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
sam_model = SamModel.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(sam_model, SamModel)
|
||||
|
||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
import uvicorn
|
||||
|
||||
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
|
||||
def get_app():
|
||||
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
|
||||
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
|
||||
@@ -18,9 +9,18 @@ def get_app():
|
||||
|
||||
def run_app() -> None:
|
||||
"""The main entrypoint for the app."""
|
||||
# Parse the CLI arguments.
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
# Parse the CLI arguments before doing anything else, which ensures CLI args correctly override settings from other
|
||||
# sources like `invokeai.yaml` or env vars.
|
||||
InvokeAIArgs.parse_args()
|
||||
|
||||
import uvicorn
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Load config.
|
||||
app_config = get_config()
|
||||
|
||||
@@ -32,6 +32,8 @@ def run_app() -> None:
|
||||
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
|
||||
|
||||
# This import must happen after configure_torch_cuda_allocator() is called, because the module imports torch.
|
||||
from invokeai.app.invocations.baseinvocation import InvocationRegistry
|
||||
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
@@ -66,6 +68,15 @@ def run_app() -> None:
|
||||
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
|
||||
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
|
||||
|
||||
# Check all invocations and ensure their outputs are registered.
|
||||
for invocation in InvocationRegistry.get_invocation_classes():
|
||||
invocation_type = invocation.get_type()
|
||||
output_annotation = invocation.get_output_annotation()
|
||||
if output_annotation not in InvocationRegistry.get_output_classes():
|
||||
logger.warning(
|
||||
f'Invocation "{invocation_type}" has unregistered output class "{output_annotation.__name__}"'
|
||||
)
|
||||
|
||||
if app_config.dev_reload:
|
||||
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
|
||||
# imported.
|
||||
|
||||
@@ -98,9 +98,18 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
params.append(board_id)
|
||||
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
|
||||
@@ -24,7 +24,6 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
@@ -93,7 +92,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
|
||||
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
||||
@@ -176,7 +175,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
|
||||
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
|
||||
|
||||
# GENERATION
|
||||
|
||||
@@ -241,6 +241,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
batch_status: BatchStatus = Field(description="The status of the batch")
|
||||
queue_status: SessionQueueStatus = Field(description="The status of the queue")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
@@ -263,6 +264,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
credits=queue_item.credits,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@@ -97,3 +98,17 @@ class ImageRecordStorageBase(ABC):
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import Field, StrictBool, StrictStr
|
||||
from pydantic import BaseModel, Field, StrictBool, StrictStr
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
@@ -207,3 +207,16 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
starred=starred,
|
||||
has_workflow=has_workflow,
|
||||
)
|
||||
|
||||
|
||||
class ImageCollectionCounts(BaseModel):
|
||||
starred_count: int = Field(description="The number of starred images in the collection.")
|
||||
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
|
||||
|
||||
|
||||
class ImageNamesResult(BaseModel):
|
||||
"""Response containing ordered image names with metadata for optimistic updates."""
|
||||
|
||||
image_names: list[str] = Field(description="Ordered list of image names")
|
||||
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
|
||||
total_count: int = Field(description="Total number of images matching the query")
|
||||
|
||||
@@ -7,6 +7,7 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
IMAGE_DTO_COLS,
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@@ -196,9 +197,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND images.metadata LIKE ?
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
@@ -382,3 +387,96 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return None
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [row[0] for row in result]
|
||||
|
||||
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
|
||||
|
||||
@@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@@ -125,7 +126,7 @@ class ImageServiceABC(ABC):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -147,3 +148,17 @@ class ImageServiceABC(ABC):
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
"""Deletes all images on a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecord
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
@@ -39,3 +39,27 @@ def image_record_to_dto(
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
class ResultWithAffectedBoards(BaseModel):
|
||||
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
|
||||
|
||||
|
||||
class DeleteImagesResult(ResultWithAffectedBoards):
|
||||
deleted_images: list[str] = Field(description="The names of the images that were deleted")
|
||||
|
||||
|
||||
class StarredImagesResult(ResultWithAffectedBoards):
|
||||
starred_images: list[str] = Field(description="The names of the images that were starred")
|
||||
|
||||
|
||||
class UnstarredImagesResult(ResultWithAffectedBoards):
|
||||
unstarred_images: list[str] = Field(description="The names of the images that were unstarred")
|
||||
|
||||
|
||||
class AddImagesToBoardResult(ResultWithAffectedBoards):
|
||||
added_images: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(ResultWithAffectedBoards):
|
||||
removed_images: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
@@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import (
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@@ -78,7 +79,7 @@ class ImageService(ImageServiceABC):
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
|
||||
self.__invoker.services.logger.warning(f"Failed to add image to board {board_id}: {str(e)}")
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
|
||||
)
|
||||
@@ -309,3 +310,27 @@ class ImageService(ImageServiceABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting intermediates count")
|
||||
raise e
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image names")
|
||||
raise e
|
||||
|
||||
@@ -27,6 +27,10 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
|
||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
|
||||
ModelRelationshipRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
|
||||
from invokeai.app.services.names.names_base import NameServiceBase
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
@@ -54,6 +58,8 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_relationships: "ModelRelationshipsServiceABC",
|
||||
model_relationship_records: "ModelRelationshipRecordStorageBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
session_queue: "SessionQueueBase",
|
||||
@@ -81,6 +87,8 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.model_relationships = model_relationships
|
||||
self.model_relationship_records = model_relationship_records
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
self.session_queue = session_queue
|
||||
|
||||
@@ -60,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self):
|
||||
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""Reset all stored statistics."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -73,9 +73,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def reset_stats(self):
|
||||
self._stats = {}
|
||||
self._cache_stats = {}
|
||||
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||
self._stats.pop(graph_execution_state_id, None)
|
||||
self._cache_stats.pop(graph_execution_state_id, None)
|
||||
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
|
||||
@@ -148,7 +148,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _clear_pending_jobs(self) -> None:
|
||||
for job in self.list_jobs():
|
||||
if not job.in_terminal_state:
|
||||
self._logger.warning("Cancelling job {job.id}")
|
||||
self._logger.warning(f"Cancelling job {job.id}")
|
||||
self.cancel_job(job)
|
||||
while True:
|
||||
try:
|
||||
@@ -647,10 +647,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
hash_algo = self._app_config.hashing_algorithm
|
||||
fields = config.model_dump()
|
||||
|
||||
# WARNING!
|
||||
# The legacy probe relies on the implicit order of tests to determine model classification.
|
||||
# This can lead to regressions between the legacy and new probes.
|
||||
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
|
||||
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
|
||||
# Long-term fix: Improve `matches` to be more specific so that only one config matches
|
||||
# any given model - eliminating ambiguity and removing reliance on order.
|
||||
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
|
||||
try:
|
||||
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
|
||||
except InvalidModelConfigException:
|
||||
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
|
||||
except InvalidModelConfigException:
|
||||
return ModelConfigBase.classify(model_path, hash_algo, **fields)
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ModelRelationshipRecordStorageBase(ABC):
|
||||
"""Abstract base class for model-to-model relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
"""Creates a relationship between two models by keys."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
"""Removes a relationship between two models by keys."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
"""Gets all models keys related to a given model key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
"""Get related model keys for multiple models given a list of keys."""
|
||||
pass
|
||||
@@ -0,0 +1,66 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
|
||||
ModelRelationshipRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
if model_key_1 == model_key_2:
|
||||
raise ValueError("Cannot relate a model to itself.")
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
|
||||
(a, b),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
|
||||
(a, b),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
|
||||
""",
|
||||
(model_key, model_key),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
key_list = ",".join("?" for _ in model_keys)
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
|
||||
""",
|
||||
model_keys + model_keys,
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
@@ -0,0 +1,25 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ModelRelationshipsServiceABC(ABC):
|
||||
"""High-level service for managing model-to-model relationships."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
"""Creates a relationship between two models keys."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
"""Removes a relationship between two models keys."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
"""Gets all models keys related to a given model key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
"""Get related model keys for multiple models."""
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from datetime import datetime
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class ModelRelationship(BaseModelExcludeNull):
|
||||
model_key_1: str
|
||||
model_key_2: str
|
||||
created_at: datetime
|
||||
@@ -0,0 +1,31 @@
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
|
||||
class ModelRelationshipsService(ModelRelationshipsServiceABC):
|
||||
__invoker: Invoker
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
self.__invoker.services.model_relationship_records.add_model_relationship(model_key_1, model_key_2)
|
||||
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
self.__invoker.services.model_relationship_records.remove_model_relationship(model_key_1, model_key_2)
|
||||
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
return self.__invoker.services.model_relationship_records.get_related_model_keys(model_key)
|
||||
|
||||
def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
|
||||
self.add_model_relationship(model_1.key, model_2.key)
|
||||
|
||||
def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
|
||||
self.remove_model_relationship(model_1.key, model_2.key)
|
||||
|
||||
def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]:
|
||||
return self.get_related_model_keys(model.key)
|
||||
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
return self.__invoker.services.model_relationship_records.get_related_model_keys_batch(model_keys)
|
||||
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
@@ -210,7 +211,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats()
|
||||
self._services.performance_statistics.reset_stats(queue_item.session.id)
|
||||
|
||||
for callback in self._on_after_run_session_callbacks:
|
||||
callback(queue_item=queue_item)
|
||||
@@ -439,6 +440,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
# GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks.
|
||||
# Most queue items take seconds to execute, so the relative cost of a GC is very small.
|
||||
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak
|
||||
# allocation is well worth it.
|
||||
gc.collect()
|
||||
|
||||
self._invoker.services.logger.info(
|
||||
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
|
||||
)
|
||||
|
||||
@@ -10,6 +10,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
@@ -17,7 +19,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
@@ -92,6 +93,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_queue_item(self, item_id: int) -> None:
|
||||
"""Deletes a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fail_queue_item(
|
||||
self, item_id: int, error_type: str, error_message: str, error_traceback: str
|
||||
@@ -109,6 +115,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items with the given batch destination"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
"""Deletes all queue items with the given batch destination"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
@@ -119,6 +130,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items except in-progress items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
"""Deletes all queue items except in-progress items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_queue_items(
|
||||
self,
|
||||
@@ -127,10 +143,20 @@ class SessionQueueBase(ABC):
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
"""Gets a page of session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
|
||||
@@ -148,7 +148,7 @@ class Batch(BaseModel):
|
||||
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
||||
except NodeNotFoundError:
|
||||
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
||||
if batch_data.field_name not in node.model_fields:
|
||||
if batch_data.field_name not in type(node).model_fields:
|
||||
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||
return values
|
||||
|
||||
@@ -205,9 +205,10 @@ class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
user_label: str | None = Field(description="The user label of the field, if any")
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
class SessionQueueItem(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
item_id: int = Field(description="The identifier of the session queue item")
|
||||
@@ -251,41 +252,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
default=None,
|
||||
description="The ID of the published workflow associated with this queue item",
|
||||
)
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
)
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The nodes that were used as output from the API"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
return SessionQueueItemDTO(**queue_item_dict)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
"queue_id",
|
||||
"session_id",
|
||||
"priority",
|
||||
"session_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
|
||||
class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
|
||||
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
default=None, description="The workflow associated with this queue item"
|
||||
@@ -365,6 +332,7 @@ class EnqueueBatchResult(BaseModel):
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
priority: int = Field(description="The priority of the enqueued batch")
|
||||
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
|
||||
|
||||
|
||||
class RetryItemsResult(BaseModel):
|
||||
@@ -396,6 +364,18 @@ class CancelByDestinationResult(CancelByBatchIDsResult):
|
||||
pass
|
||||
|
||||
|
||||
class DeleteByDestinationResult(BaseModel):
|
||||
"""Result of deleting by a destination"""
|
||||
|
||||
deleted: int = Field(..., description="Number of queue items deleted")
|
||||
|
||||
|
||||
class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
|
||||
"""Result of deleting all except current"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
@@ -24,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
ValueToInsertTuple,
|
||||
@@ -46,10 +47,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
clear_result = self.clear(DEFAULT_QUEUE_ID)
|
||||
if clear_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
|
||||
else:
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
@@ -104,11 +101,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
|
||||
|
||||
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
@@ -118,8 +111,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = calc_session_count(batch)
|
||||
values_to_insert = prepare_values_to_insert(
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
@@ -127,19 +124,28 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
if requested_count > enqueued_count:
|
||||
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT item_id
|
||||
FROM session_queue
|
||||
WHERE batch_id = ?
|
||||
ORDER BY item_id DESC;
|
||||
""",
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
@@ -147,6 +153,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
enqueued=enqueued_count,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
item_ids=item_ids,
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
|
||||
return enqueue_result
|
||||
@@ -220,6 +227,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status FROM session_queue WHERE item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -331,6 +351,27 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
return queue_item
|
||||
|
||||
def delete_queue_item(self, item_id: int) -> None:
|
||||
"""Deletes a session queue item"""
|
||||
try:
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
return queue_item
|
||||
@@ -363,6 +404,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
cursor.execute(
|
||||
@@ -401,6 +444,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
cursor.execute(
|
||||
@@ -428,6 +473,71 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
raise
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
params = (queue_id, destination)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND destination = ?;
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND destination = ?;
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND status == 'pending'
|
||||
"""
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
@@ -438,6 +548,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = [queue_id]
|
||||
cursor.execute(
|
||||
@@ -458,12 +570,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
@@ -543,26 +652,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -574,6 +669,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
@@ -589,7 +690,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
@@ -597,6 +698,37 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
has_more = True
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
cursor_ = self._conn.cursor()
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -611,7 +743,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] for row in counts_result)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
return SessionQueueStatus(
|
||||
queue_id=queue_id,
|
||||
@@ -640,7 +772,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
total = sum(row[1] or 0 for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
@@ -672,7 +804,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] for row in counts_result)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
|
||||
return SessionQueueCountsByDestination(
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
ValidationError,
|
||||
@@ -57,17 +58,32 @@ class Edge(BaseModel):
|
||||
|
||||
|
||||
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_outputs = get_type_hints(node_type.get_output_annotation())
|
||||
node_output_field = node_outputs.get(field) or None
|
||||
return node_output_field
|
||||
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
|
||||
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
|
||||
# would require some fairly significant changes and I don't want risk breaking anything.
|
||||
try:
|
||||
invocation_class = type(node)
|
||||
invocation_output_class = invocation_class.get_output_annotation()
|
||||
field_info = invocation_output_class.model_fields.get(field)
|
||||
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
|
||||
output_field_type = field_info.annotation
|
||||
return output_field_type
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
|
||||
node_type = type(node)
|
||||
node_inputs = get_type_hints(node_type)
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
|
||||
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
|
||||
# would require some fairly significant changes and I don't want risk breaking anything.
|
||||
try:
|
||||
invocation_class = type(node)
|
||||
field_info = invocation_class.model_fields.get(field)
|
||||
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
|
||||
input_field_type = field_info.annotation
|
||||
return input_field_type
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
@@ -424,7 +440,7 @@ class Graph(BaseModel):
|
||||
)
|
||||
|
||||
# input fields are on the node
|
||||
if edge.destination.field not in destination_node.model_fields:
|
||||
if edge.destination.field not in type(destination_node).model_fields:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||
)
|
||||
@@ -787,6 +803,22 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
"executed",
|
||||
"executed_history",
|
||||
"results",
|
||||
"errors",
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
@@ -975,10 +1007,11 @@ class GraphExecutionState(BaseModel):
|
||||
new_node_ids = []
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(
|
||||
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
|
||||
)
|
||||
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||
all_iteration_mappings = []
|
||||
for source_node_id in next_node_parents:
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
||||
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
|
||||
|
||||
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
|
||||
@@ -21,6 +21,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.step_callback import diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
ModelConfigBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
@@ -543,6 +544,30 @@ class ModelsInterface(InvocationContextInterface):
|
||||
self._util.signal_progress(f"Loading model {source}")
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path:
|
||||
"""Gets the absolute path for a given model config or path.
|
||||
|
||||
For example, if the model's path is `flux/main/FLUX Dev.safetensors`, and the models path is
|
||||
`/home/username/InvokeAI/models`, this method will return
|
||||
`/home/username/InvokeAI/models/flux/main/FLUX Dev.safetensors`.
|
||||
|
||||
Args:
|
||||
config_or_path: The model config or path.
|
||||
|
||||
Returns:
|
||||
The absolute path to the model.
|
||||
"""
|
||||
|
||||
model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path)
|
||||
|
||||
if model_path.is_absolute():
|
||||
return model_path.resolve()
|
||||
|
||||
base_models_path = self._services.configuration.models_path
|
||||
joined_path = base_models_path / model_path
|
||||
resolved_path = joined_path.resolve()
|
||||
return resolved_path
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
|
||||
@@ -22,6 +22,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -61,6 +62,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_17())
|
||||
migrator.register_migration(build_migration_18())
|
||||
migrator.register_migration(build_migration_19(app_config=config))
|
||||
migrator.register_migration(build_migration_20())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration20Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""
|
||||
-- many-to-many relationship table for models
|
||||
CREATE TABLE IF NOT EXISTS model_relationships (
|
||||
-- model_key_1 and model_key_2 are the same as the key(primary key) in the models table
|
||||
model_key_1 TEXT NOT NULL,
|
||||
model_key_2 TEXT NOT NULL,
|
||||
created_at TEXT DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
PRIMARY KEY (model_key_1, model_key_2),
|
||||
-- model_key_1 < model_key_2, to ensure uniqueness and prevent duplicates
|
||||
FOREIGN KEY (model_key_1) REFERENCES models(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (model_key_2) REFERENCES models(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
-- Creates an index to keep performance equal when searching for model_key_1 or model_key_2
|
||||
CREATE INDEX IF NOT EXISTS keyx_model_relationships_model_key_2
|
||||
ON model_relationships(model_key_2)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_20() -> Migration:
|
||||
return Migration(
|
||||
from_version=19,
|
||||
to_version=20,
|
||||
callback=Migration20Callback(),
|
||||
)
|
||||
@@ -230,6 +230,86 @@ def heuristic_resize(np_img: np.ndarray[Any, Any], size: tuple[int, int]) -> np.
|
||||
return resized
|
||||
|
||||
|
||||
# precompute common kernels
|
||||
_KERNEL3 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
|
||||
# directional masks for NMS
|
||||
_DIRS = [
|
||||
np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], np.uint8),
|
||||
np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], np.uint8),
|
||||
np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], np.uint8),
|
||||
np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], np.uint8),
|
||||
]
|
||||
|
||||
|
||||
def heuristic_resize_fast(np_img: np.ndarray, size: tuple[int, int]) -> np.ndarray:
|
||||
h, w = np_img.shape[:2]
|
||||
# early exit
|
||||
if (w, h) == size:
|
||||
return np_img
|
||||
|
||||
# separate alpha channel
|
||||
img = np_img
|
||||
alpha = None
|
||||
if img.ndim == 3 and img.shape[2] == 4:
|
||||
alpha, img = img[:, :, 3], img[:, :, :3]
|
||||
|
||||
# build small sample for unique‐color & binary detection
|
||||
flat = img.reshape(-1, img.shape[-1])
|
||||
N = flat.shape[0]
|
||||
# include four corners to avoid missing extreme values
|
||||
corners = np.vstack([img[0, 0], img[0, w - 1], img[h - 1, 0], img[h - 1, w - 1]])
|
||||
cnt = min(N, 100_000)
|
||||
samp = np.vstack([corners, flat[np.random.choice(N, cnt, replace=False)]])
|
||||
uc = np.unique(samp, axis=0).shape[0]
|
||||
vmin, vmax = samp.min(), samp.max()
|
||||
|
||||
# detect binary edge map & one‐pixel‐edge case
|
||||
is_binary = uc == 2 and vmin < 16 and vmax > 240
|
||||
one_pixel_edge = False
|
||||
if is_binary:
|
||||
# single gray conversion
|
||||
gray0 = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
grad = cv2.morphologyEx(gray0, cv2.MORPH_GRADIENT, _KERNEL3)
|
||||
cnt_edge = cv2.countNonZero(grad)
|
||||
cnt_all = cv2.countNonZero((gray0 > 127).astype(np.uint8))
|
||||
one_pixel_edge = (2 * cnt_edge) > cnt_all
|
||||
|
||||
# choose interp for color/seg/grayscale
|
||||
area_new, area_old = size[0] * size[1], w * h
|
||||
if 2 < uc < 200: # segmentation map
|
||||
interp = cv2.INTER_NEAREST
|
||||
elif area_new < area_old:
|
||||
interp = cv2.INTER_AREA
|
||||
else:
|
||||
interp = cv2.INTER_CUBIC
|
||||
|
||||
# single resize pass on RGB
|
||||
resized = cv2.resize(img, size, interpolation=interp)
|
||||
|
||||
if is_binary:
|
||||
# convert to gray & apply NMS via C++ dilate
|
||||
gray_r = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
|
||||
nms = np.zeros_like(gray_r)
|
||||
for K in _DIRS:
|
||||
d = cv2.dilate(gray_r, K)
|
||||
mask = d == gray_r
|
||||
nms[mask] = gray_r[mask]
|
||||
|
||||
# threshold + thinning if needed
|
||||
_, bw = cv2.threshold(nms, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
out_bin = cv2.ximgproc.thinning(bw) if one_pixel_edge else bw
|
||||
# restore 3 channels
|
||||
resized = np.stack([out_bin] * 3, axis=2)
|
||||
|
||||
# restore alpha with same interp as RGB for consistency
|
||||
if alpha is not None:
|
||||
am = cv2.resize(alpha, size, interpolation=interp)
|
||||
am = (am > 127).astype(np.uint8) * 255
|
||||
resized = np.dstack((resized, am))
|
||||
|
||||
return resized
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
|
||||
# modified for InvokeAI
|
||||
@@ -244,7 +324,7 @@ def np_img_resize(
|
||||
np_img = normalize_image_channel_count(np_img)
|
||||
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
np_img = heuristic_resize(np_img, (w, h))
|
||||
np_img = heuristic_resize_fast(np_img, (w, h))
|
||||
np_img = clone_contiguous(np_img)
|
||||
return np_img_to_torch(np_img, device), np_img
|
||||
|
||||
@@ -265,7 +345,7 @@ def np_img_resize(
|
||||
# Inpaint hijack
|
||||
high_quality_border_color[3] = 255
|
||||
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
|
||||
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
np_img = heuristic_resize_fast(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (h - new_h) // 2)
|
||||
pad_w = max(0, (w - new_w) // 2)
|
||||
@@ -275,7 +355,7 @@ def np_img_resize(
|
||||
return np_img_to_torch(np_img, device), np_img
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
k = max(k0, k1)
|
||||
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
np_img = heuristic_resize_fast(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (new_h - h) // 2)
|
||||
pad_w = max(0, (new_w - w) // 2)
|
||||
|
||||
@@ -12,6 +12,9 @@ from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFie
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
|
||||
@@ -61,6 +64,10 @@ def get_openapi_func(
|
||||
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||
for output in InvocationRegistry.get_output_classes():
|
||||
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
# Remove output_metadata that is only used on back-end from the schema
|
||||
if "output_meta" in json_schema["properties"]:
|
||||
json_schema["properties"].pop("output_meta")
|
||||
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ def get_timestamp() -> int:
|
||||
|
||||
|
||||
def get_iso_timestamp() -> str:
|
||||
return datetime.datetime.utcnow().isoformat()
|
||||
return datetime.datetime.now(datetime.timezone.utc).isoformat()
|
||||
|
||||
|
||||
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
|
||||
@@ -123,7 +123,11 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
|
||||
if total_steps == 0:
|
||||
return 0.0
|
||||
if order == 2:
|
||||
return floor(step / 2) / floor(total_steps / 2)
|
||||
# Prevent division by zero when total_steps is 1 or 2
|
||||
denominator = floor(total_steps / 2)
|
||||
if denominator == 0:
|
||||
return 0.0
|
||||
return floor(step / 2) / denominator
|
||||
# order == 1
|
||||
return step / total_steps
|
||||
|
||||
|
||||
@@ -30,8 +30,11 @@ def denoise(
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
# extra img tokens
|
||||
# extra img tokens (channel-wise)
|
||||
img_cond: torch.Tensor | None,
|
||||
# extra img tokens (sequence-wise) - for Kontext conditioning
|
||||
img_cond_seq: torch.Tensor | None = None,
|
||||
img_cond_seq_ids: torch.Tensor | None = None,
|
||||
):
|
||||
# step 0 is the initial state
|
||||
total_steps = len(timesteps) - 1
|
||||
@@ -46,6 +49,10 @@ def denoise(
|
||||
)
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
# Store original sequence length for slicing predictions
|
||||
original_seq_len = img.shape[1]
|
||||
|
||||
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
@@ -71,10 +78,26 @@ def denoise(
|
||||
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
||||
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
|
||||
|
||||
# Prepare input for model - concatenate fresh each step
|
||||
img_input = img
|
||||
img_input_ids = img_ids
|
||||
|
||||
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
|
||||
if img_cond is not None:
|
||||
img_input = torch.cat((img_input, img_cond), dim=-1)
|
||||
|
||||
# Add sequence-wise conditioning (for Kontext)
|
||||
if img_cond_seq is not None:
|
||||
assert img_cond_seq_ids is not None, (
|
||||
"You need to provide either both or neither of the sequence conditioning"
|
||||
)
|
||||
img_input = torch.cat((img_input, img_cond_seq), dim=1)
|
||||
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
|
||||
|
||||
pred = model(
|
||||
img=pred_img,
|
||||
img_ids=img_ids,
|
||||
img=img_input,
|
||||
img_ids=img_input_ids,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
@@ -88,6 +111,10 @@ def denoise(
|
||||
regional_prompting_extension=pos_regional_prompting_extension,
|
||||
)
|
||||
|
||||
# Slice prediction to only include the main image tokens
|
||||
if img_input_ids is not None:
|
||||
pred = pred[:, :original_seq_len]
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
|
||||
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
|
||||
|
||||
149
invokeai/backend/flux/extensions/kontext_extension.py
Normal file
149
invokeai/backend/flux/extensions/kontext_extension.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.fields import FluxKontextConditioningField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import pack
|
||||
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
|
||||
|
||||
def generate_img_ids_with_offset(
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
idx_offset: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids with an optional offset.
|
||||
|
||||
Args:
|
||||
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
|
||||
latent_width (int): Width of image in latent space (after packing, this becomes w//2).
|
||||
batch_size (int): Number of images in the batch.
|
||||
device (torch.device): Device to create tensors on.
|
||||
dtype (torch.dtype): Data type for the tensors.
|
||||
idx_offset (int): Offset to add to the first dimension of the image ids.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
|
||||
"""
|
||||
|
||||
if device.type == "mps":
|
||||
orig_dtype = dtype
|
||||
dtype = torch.float16
|
||||
|
||||
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
|
||||
packed_height = latent_height // 2
|
||||
packed_width = latent_width // 2
|
||||
|
||||
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
|
||||
# The 3 channels represent: [batch_offset, y_position, x_position]
|
||||
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
|
||||
|
||||
# Set the batch offset for all positions
|
||||
img_ids[..., 0] = idx_offset
|
||||
|
||||
# Create y-coordinate indices (vertical positions)
|
||||
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
|
||||
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
|
||||
img_ids[..., 1] = y_indices[:, None]
|
||||
|
||||
# Create x-coordinate indices (horizontal positions)
|
||||
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
|
||||
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
|
||||
img_ids[..., 2] = x_indices[None, :]
|
||||
|
||||
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
|
||||
if device.type == "mps":
|
||||
img_ids = img_ids.to(orig_dtype)
|
||||
|
||||
return img_ids
|
||||
|
||||
|
||||
class KontextExtension:
|
||||
"""Applies FLUX Kontext (reference image) conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kontext_conditioning: FluxKontextConditioningField,
|
||||
context: InvocationContext,
|
||||
vae_field: VAEField,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Initializes the KontextExtension, pre-processing the reference image
|
||||
into latents and positional IDs.
|
||||
"""
|
||||
self._context = context
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
self._vae_field = vae_field
|
||||
self.kontext_conditioning = kontext_conditioning
|
||||
|
||||
# Pre-process and cache the kontext latents and ids upon initialization.
|
||||
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
|
||||
|
||||
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encodes the reference image and prepares its latents and IDs."""
|
||||
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
|
||||
|
||||
# Calculate aspect ratio of input image
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Find the closest preferred resolution by aspect ratio
|
||||
_, target_width, target_height = min(
|
||||
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
|
||||
)
|
||||
|
||||
# Apply BFL's scaling formula
|
||||
# This ensures compatibility with the model's training
|
||||
scaled_width = 2 * int(target_width / 16)
|
||||
scaled_height = 2 * int(target_height / 16)
|
||||
|
||||
# Resize to the exact resolution used during training
|
||||
image = image.convert("RGB")
|
||||
final_width = 8 * scaled_width
|
||||
final_height = 8 * scaled_height
|
||||
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Convert to tensor with same normalization as BFL
|
||||
image_np = np.array(image)
|
||||
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
|
||||
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
|
||||
image_tensor = image_tensor.to(self._device)
|
||||
|
||||
# Continue with VAE encoding
|
||||
vae_info = self._context.models.load(self._vae_field.vae)
|
||||
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
# Extract tensor dimensions
|
||||
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
|
||||
|
||||
# Pack the latents and generate IDs
|
||||
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
|
||||
kontext_ids = generate_img_ids_with_offset(
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
batch_size=batch_size,
|
||||
device=self._device,
|
||||
dtype=self._dtype,
|
||||
idx_offset=1,
|
||||
)
|
||||
|
||||
return kontext_latents_packed, kontext_ids
|
||||
|
||||
def ensure_batch_size(self, target_batch_size: int) -> None:
|
||||
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
|
||||
if self.kontext_latents.shape[0] != target_batch_size:
|
||||
self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
|
||||
self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)
|
||||
@@ -174,11 +174,13 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
|
||||
dtype = torch.float16
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
# Set batch offset to 0 for main image tokens
|
||||
img_ids[..., 0] = 0
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
|
||||
if device.type == "mps":
|
||||
img_ids.to(orig_dtype)
|
||||
img_ids = img_ids.to(orig_dtype)
|
||||
|
||||
return img_ids
|
||||
|
||||
@@ -18,6 +18,29 @@ class ModelSpec:
|
||||
repo_ae: str | None
|
||||
|
||||
|
||||
# Preferred resolutions for Kontext models to avoid tiling artifacts
|
||||
# These are the specific resolutions the model was trained on
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
(752, 1392),
|
||||
(800, 1328),
|
||||
(832, 1248),
|
||||
(880, 1184),
|
||||
(944, 1104),
|
||||
(1024, 1024),
|
||||
(1104, 944),
|
||||
(1184, 880),
|
||||
(1248, 832),
|
||||
(1328, 800),
|
||||
(1392, 752),
|
||||
(1456, 720),
|
||||
(1504, 688),
|
||||
(1568, 672),
|
||||
]
|
||||
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-dev-fill": 512,
|
||||
|
||||
@@ -42,4 +42,5 @@ IP-Adapters:
|
||||
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
|
||||
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
|
||||
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
|
||||
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
|
||||
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
|
||||
- [InvokeAI/ip-adapter-plus_sdxl_vit-h](https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h)
|
||||
@@ -1,26 +1,15 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
from transformers import LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
|
||||
|
||||
|
||||
class LlavaOnevisionModel(RawModel):
|
||||
class LlavaOnevisionPipeline:
|
||||
"""A wrapper for a LLaVA Onevision model + processor."""
|
||||
|
||||
def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
|
||||
self._vllm_model = vllm_model
|
||||
self._processor = processor
|
||||
|
||||
@classmethod
|
||||
def load_from_path(cls, path: str | Path):
|
||||
vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
|
||||
processor = AutoProcessor.from_pretrained(path, local_files_only=True)
|
||||
assert isinstance(processor, LlavaOnevisionProcessor)
|
||||
return cls(vllm_model, processor)
|
||||
|
||||
def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
|
||||
# TODO(ryand): Tune the max number of images that are useful for the model.
|
||||
if len(images) > 3:
|
||||
@@ -44,13 +33,3 @@ class LlavaOnevisionModel(RawModel):
|
||||
# The output_str will include the prompt, so we extract the response.
|
||||
response = output_str.split("assistant\n", 1)[1].strip()
|
||||
return response
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
self._vllm_model.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
"""Get size of the model in memory in bytes."""
|
||||
# HACK(ryand): Fix this issue with circular imports.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._vllm_model)
|
||||
@@ -37,6 +37,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
@@ -144,34 +145,37 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
|
||||
description="Loadable submodels in this model", default=None
|
||||
)
|
||||
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
|
||||
|
||||
_USING_LEGACY_PROBE: ClassVar[set] = set()
|
||||
_USING_CLASSIFY_API: ClassVar[set] = set()
|
||||
USING_LEGACY_PROBE: ClassVar[set] = set()
|
||||
USING_CLASSIFY_API: ClassVar[set] = set()
|
||||
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, LegacyProbeMixin):
|
||||
ModelConfigBase._USING_LEGACY_PROBE.add(cls)
|
||||
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
|
||||
else:
|
||||
ModelConfigBase._USING_CLASSIFY_API.add(cls)
|
||||
ModelConfigBase.USING_CLASSIFY_API.add(cls)
|
||||
|
||||
@staticmethod
|
||||
def all_config_classes():
|
||||
subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
|
||||
subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
|
||||
concrete = {cls for cls in subclasses if not isabstract(cls)}
|
||||
return concrete
|
||||
|
||||
@staticmethod
|
||||
def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
|
||||
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
|
||||
"""
|
||||
Returns the best matching ModelConfig instance from a model's file/folder path.
|
||||
Raises InvalidModelConfigException if no valid configuration is found.
|
||||
Created to deprecate ModelProbe.probe
|
||||
"""
|
||||
candidates = ModelConfigBase._USING_CLASSIFY_API
|
||||
if isinstance(mod, Path | str):
|
||||
mod = ModelOnDisk(mod, hash_algo)
|
||||
|
||||
candidates = ModelConfigBase.USING_CLASSIFY_API
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
|
||||
mod = ModelOnDisk(model_path, hash_algo)
|
||||
|
||||
for config_cls in sorted_by_match_speed:
|
||||
try:
|
||||
@@ -293,7 +297,7 @@ class LoRAConfigBase(ABC, BaseModel):
|
||||
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
||||
|
||||
sd = mod.load_state_dict(mod.path)
|
||||
value = flux_format_from_state_dict(sd)
|
||||
value = flux_format_from_state_dict(sd, mod.metadata())
|
||||
mod.cache[key] = value
|
||||
return value
|
||||
|
||||
@@ -331,6 +335,36 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
|
||||
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
|
||||
metadata = mod.metadata()
|
||||
return (
|
||||
metadata.get("modelspec.sai_model_spec")
|
||||
and metadata.get("ot_branch") == "omi_format"
|
||||
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
metadata = mod.metadata()
|
||||
architecture = metadata["modelspec.architecture"]
|
||||
|
||||
if architecture == stable_diffusion_xl_1_lora:
|
||||
base = BaseModelType.StableDiffusionXL
|
||||
elif architecture == flux_dev_1_lora:
|
||||
base = BaseModelType.Flux
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
|
||||
|
||||
return {"base": base}
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
@@ -347,7 +381,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if type(key) is int:
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
|
||||
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
@@ -600,6 +634,21 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
}
|
||||
|
||||
|
||||
class ApiModelConfig(MainConfigBase, ModelConfigBase):
|
||||
"""Model config for API-based models."""
|
||||
|
||||
format: Literal[ModelFormat.Api] = ModelFormat.Api
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
# API models are not stored on disk, so we can't match them.
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
raise NotImplementedError("API models are not parsed from disk.")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
@@ -650,6 +699,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
|
||||
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
|
||||
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
@@ -667,6 +717,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
|
||||
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
|
||||
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
|
||||
Annotated[ApiModelConfig, ApiModelConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -2,6 +2,8 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
|
||||
|
||||
class CachedModelOnlyFullLoad:
|
||||
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
|
||||
@@ -76,7 +78,15 @@ class CachedModelOnlyFullLoad:
|
||||
for k, v in self._cpu_state_dict.items():
|
||||
new_state_dict[k] = v.to(self._compute_device, copy=True)
|
||||
self._model.load_state_dict(new_state_dict, assign=True)
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
|
||||
if isinstance(check_for_gguf, GGMLTensor):
|
||||
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(True)
|
||||
self._model.to(self._compute_device)
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(old_value)
|
||||
else:
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
self._is_in_vram = True
|
||||
return self._total_bytes
|
||||
@@ -92,7 +102,15 @@ class CachedModelOnlyFullLoad:
|
||||
|
||||
if self._cpu_state_dict is not None:
|
||||
self._model.load_state_dict(self._cpu_state_dict, assign=True)
|
||||
self._model.to(self._offload_device)
|
||||
|
||||
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
|
||||
if isinstance(check_for_gguf, GGMLTensor):
|
||||
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(True)
|
||||
self._model.to(self._offload_device)
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(old_value)
|
||||
else:
|
||||
self._model.to(self._offload_device)
|
||||
|
||||
self._is_in_vram = False
|
||||
return self._total_bytes
|
||||
|
||||
@@ -2,9 +2,10 @@ import gc
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from logging import Logger
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@@ -54,6 +55,39 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntrySnapshot:
|
||||
cache_key: str
|
||||
total_bytes: int
|
||||
current_vram_bytes: int
|
||||
|
||||
|
||||
class CacheMissCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
model_key: str,
|
||||
cache_snapshot: dict[str, CacheEntrySnapshot],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class CacheHitCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
model_key: str,
|
||||
cache_snapshot: dict[str, CacheEntrySnapshot],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class CacheModelsClearedCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
models_cleared: int,
|
||||
bytes_requested: int,
|
||||
bytes_freed: int,
|
||||
cache_snapshot: dict[str, CacheEntrySnapshot],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class ModelCache:
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
@@ -144,6 +178,34 @@ class ModelCache:
|
||||
# - Requests to empty the cache from a separate thread
|
||||
self._lock = threading.RLock()
|
||||
|
||||
self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
|
||||
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
|
||||
self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
|
||||
|
||||
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
|
||||
self._on_cache_hit_callbacks.add(cb)
|
||||
|
||||
def unsubscribe() -> None:
|
||||
self._on_cache_hit_callbacks.discard(cb)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
|
||||
self._on_cache_miss_callbacks.add(cb)
|
||||
|
||||
def unsubscribe() -> None:
|
||||
self._on_cache_miss_callbacks.discard(cb)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
def on_cache_models_cleared(self, cb: CacheModelsClearedCallback) -> Callable[[], None]:
|
||||
self._on_cache_models_cleared_callbacks.add(cb)
|
||||
|
||||
def unsubscribe() -> None:
|
||||
self._on_cache_models_cleared_callbacks.discard(cb)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
@property
|
||||
@synchronized
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
@@ -195,6 +257,20 @@ class ModelCache:
|
||||
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
|
||||
)
|
||||
|
||||
@synchronized
|
||||
def _get_cache_snapshot(self) -> dict[str, CacheEntrySnapshot]:
|
||||
overview: dict[str, CacheEntrySnapshot] = {}
|
||||
for cache_key, cache_entry in self._cached_models.items():
|
||||
total_bytes = cache_entry.cached_model.total_bytes()
|
||||
current_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
overview[cache_key] = CacheEntrySnapshot(
|
||||
cache_key=cache_key,
|
||||
total_bytes=total_bytes,
|
||||
current_vram_bytes=current_vram_bytes,
|
||||
)
|
||||
|
||||
return overview
|
||||
|
||||
@synchronized
|
||||
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
@@ -208,6 +284,8 @@ class ModelCache:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
for cb in self._on_cache_miss_callbacks:
|
||||
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
self._logger.debug(f"Cache miss: {key}")
|
||||
@@ -229,6 +307,8 @@ class ModelCache:
|
||||
self._cache_stack.append(key)
|
||||
|
||||
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
for cb in self._on_cache_hit_callbacks:
|
||||
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
|
||||
return cache_entry
|
||||
|
||||
@synchronized
|
||||
@@ -649,6 +729,13 @@ class ModelCache:
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
for cb in self._on_cache_models_cleared_callbacks:
|
||||
cb(
|
||||
models_cleared=models_cleared,
|
||||
bytes_requested=bytes_needed,
|
||||
bytes_freed=ram_bytes_freed,
|
||||
cache_snapshot=self._get_cache_snapshot(),
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
@@ -13,6 +13,12 @@ from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
|
||||
def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
|
||||
# up matrix and down matrix have different ranks so we can't simply multiply them
|
||||
if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
|
||||
x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
|
||||
x *= lora_weight * lora_layer.scale()
|
||||
return x
|
||||
|
||||
x = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, lora_layer.mid)
|
||||
|
||||
@@ -7,7 +7,14 @@ from typing import Optional
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForTextEncoding,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@@ -139,7 +146,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
)
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
te2_model_path = Path(config.path) / "text_encoder_2"
|
||||
model_config = AutoConfig.from_pretrained(te2_model_path)
|
||||
@@ -183,7 +190,7 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
return T5EncoderModel.from_pretrained(
|
||||
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from transformers import LlavaOnevisionForConditionalGeneration
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
)
|
||||
@@ -23,6 +24,8 @@ class LlavaOnevisionModelLoader(ModelLoader):
|
||||
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
model = LlavaOnevisionModel.load_from_path(model_path)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
model_path, local_files_only=True, torch_dtype=self._torch_dtype
|
||||
)
|
||||
assert isinstance(model, LlavaOnevisionForConditionalGeneration)
|
||||
return model
|
||||
|
||||
@@ -13,6 +13,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.omi.omi import convert_from_omi
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
@@ -20,6 +21,10 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_aitoolkit_format,
|
||||
lora_model_from_flux_aitoolkit_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
|
||||
is_state_dict_likely_flux_control,
|
||||
lora_model_from_flux_control_state_dict,
|
||||
@@ -39,6 +44,8 @@ from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import l
|
||||
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
|
||||
@@ -73,12 +80,23 @@ class LoRALoader(ModelLoader):
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Strip 'bundle_emb' keys - these are unused and currently cause downstream errors.
|
||||
# To revisit later to determine if they're needed/useful.
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")}
|
||||
|
||||
# At the time of writing, we support the OMI standard for base models Flux and SDXL
|
||||
if config.format == ModelFormat.OMI and self._model_base in [
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.Flux,
|
||||
]:
|
||||
state_dict = convert_from_omi(state_dict, config.base) # type: ignore
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
|
||||
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
|
||||
# distributed as a single file without the associated metadata containing the alpha value. We chose
|
||||
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
|
||||
@@ -92,8 +110,10 @@ class LoRALoader(ModelLoader):
|
||||
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_flux_control(state_dict=state_dict):
|
||||
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
raise ValueError("LoRA model is in unsupported FLUX format")
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from transformers import SiglipVisionModel
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.SigLIP, format=ModelFormat.Diffusers)
|
||||
@@ -23,6 +24,5 @@ class SigLIPModelLoader(ModelLoader):
|
||||
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
model = SigLipPipeline.load_from_path(model_path)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True, torch_dtype=self._torch_dtype)
|
||||
return model
|
||||
@@ -16,11 +16,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
@@ -51,8 +49,6 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
GroundingDinoPipeline,
|
||||
SegmentAnythingPipeline,
|
||||
DepthAnythingPipeline,
|
||||
SigLipPipeline,
|
||||
LlavaOnevisionModel,
|
||||
),
|
||||
):
|
||||
return model.calc_size()
|
||||
|
||||
@@ -62,11 +62,14 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
# If this too fails, raise exception.
|
||||
|
||||
model_info = None
|
||||
|
||||
# Handling for our special syntax - we only want the base HF `org/repo` here.
|
||||
repo_id = id.split("::")[0] or id
|
||||
while not model_info:
|
||||
try:
|
||||
model_info = HfApi().model_info(repo_id=id, files_metadata=True, revision=variant)
|
||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True, revision=variant)
|
||||
except RepositoryNotFoundError as excp:
|
||||
raise UnknownMetadataException(f"'{id}' not found. See trace for details.") from excp
|
||||
raise UnknownMetadataException(f"'{repo_id}' not found. See trace for details.") from excp
|
||||
except RevisionNotFoundError:
|
||||
if variant is None:
|
||||
raise
|
||||
@@ -75,14 +78,14 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
|
||||
files: list[RemoteModelFile] = []
|
||||
|
||||
_, name = id.split("/")
|
||||
_, name = repo_id.split("/")
|
||||
|
||||
for s in model_info.siblings or []:
|
||||
assert s.rfilename is not None
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
|
||||
url=hf_hub_url(repo_id, s.rfilename, revision=variant or "main"),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Optional, TypeAlias
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors import safe_open
|
||||
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
@@ -35,12 +36,21 @@ class ModelOnDisk:
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self) -> set[Path]:
|
||||
def weight_files(self) -> set[Path]:
|
||||
if self.path.is_file():
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
|
||||
try:
|
||||
with safe_open(self.path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
assert isinstance(metadata, dict)
|
||||
return metadata
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def repo_variant(self) -> Optional[ModelRepoVariant]:
|
||||
if self.path.is_file():
|
||||
return None
|
||||
@@ -64,18 +74,7 @@ class ModelOnDisk:
|
||||
if path in sd_cache:
|
||||
return sd_cache[path]
|
||||
|
||||
if not path:
|
||||
components = list(self.component_paths())
|
||||
match components:
|
||||
case []:
|
||||
raise ValueError("No weight files found for this model")
|
||||
case [p]:
|
||||
path = p
|
||||
case ps if len(ps) >= 2:
|
||||
raise ValueError(
|
||||
f"Multiple weight files found for this model: {ps}. "
|
||||
f"Please specify the intended file using the 'path' argument"
|
||||
)
|
||||
path = self.resolve_weight_file(path)
|
||||
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
@@ -94,3 +93,18 @@ class ModelOnDisk:
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
sd_cache[path] = state_dict
|
||||
return state_dict
|
||||
|
||||
def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
|
||||
if not path:
|
||||
weight_files = list(self.weight_files())
|
||||
match weight_files:
|
||||
case []:
|
||||
raise ValueError("No weight files found for this model")
|
||||
case [p]:
|
||||
return p
|
||||
case ps if len(ps) >= 2:
|
||||
raise ValueError(
|
||||
f"Multiple weight files found for this model: {ps}. "
|
||||
f"Please specify the intended file using the 'path' argument"
|
||||
)
|
||||
return path
|
||||
|
||||
7
invokeai/backend/model_manager/omi/__init__.py
Normal file
7
invokeai/backend/model_manager/omi/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from invokeai.backend.model_manager.omi.omi import convert_from_omi
|
||||
from invokeai.backend.model_manager.omi.vendor.model_spec.architecture import (
|
||||
flux_dev_1_lora,
|
||||
stable_diffusion_xl_1_lora,
|
||||
)
|
||||
|
||||
__all__ = ["flux_dev_1_lora", "stable_diffusion_xl_1_lora", "convert_from_omi"]
|
||||
21
invokeai/backend/model_manager/omi/omi.py
Normal file
21
invokeai/backend/model_manager/omi/omi.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from invokeai.backend.model_manager.model_on_disk import StateDict
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
|
||||
convert_flux_lora as omi_flux,
|
||||
)
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
|
||||
convert_lora_util as lora_util,
|
||||
)
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
|
||||
convert_sdxl_lora as omi_sdxl,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
|
||||
keyset = {
|
||||
BaseModelType.Flux: omi_flux.convert_flux_lora_key_sets(),
|
||||
BaseModelType.StableDiffusionXL: omi_sdxl.convert_sdxl_lora_key_sets(),
|
||||
}[base]
|
||||
source = "omi"
|
||||
target = "legacy_diffusers"
|
||||
return lora_util.__convert(weights_sd, keyset, source, target) # type: ignore
|
||||
0
invokeai/backend/model_manager/omi/vendor/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/lora/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/convert/lora/__init__.py
vendored
Normal file
20
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_clip.py
vendored
Normal file
20
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_clip.py
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
|
||||
|
||||
def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("text_projection", "text_projection", parent=key_prefix)]
|
||||
|
||||
for k in map_prefix_range("text_model.encoder.layers", "text_model.encoder.layers", parent=key_prefix):
|
||||
keys += [LoraConversionKeySet("mlp.fc1", "mlp.fc1", parent=k)]
|
||||
keys += [LoraConversionKeySet("mlp.fc2", "mlp.fc2", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.k_proj", "self_attn.k_proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.out_proj", "self_attn.out_proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.q_proj", "self_attn.q_proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("self_attn.v_proj", "self_attn.v_proj", parent=k)]
|
||||
|
||||
return keys
|
||||
84
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_flux_lora.py
vendored
Normal file
84
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_flux_lora.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5
|
||||
|
||||
|
||||
def __map_double_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("img_attn.qkv.0", "attn.to_q", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_attn.qkv.1", "attn.to_k", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_attn.qkv.2", "attn.to_v", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("txt_attn.qkv.0", "attn.add_q_proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_attn.qkv.1", "attn.add_k_proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_attn.qkv.2", "attn.add_v_proj", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("img_attn.proj", "attn.to_out.0", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_mlp.0", "ff.net.0.proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_mlp.2", "ff.net.2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_mod.lin", "norm1.linear", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("txt_attn.proj", "attn.to_add_out", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_mlp.0", "ff_context.net.0.proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_mlp.2", "ff_context.net.2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("txt_mod.lin", "norm1_context.linear", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_single_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("linear1.0", "attn.to_q", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("linear1.1", "attn.to_k", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("linear1.2", "attn.to_v", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("linear1.3", "proj_mlp", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("linear2", "proj_out", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("modulation.lin", "norm.linear", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
|
||||
keys += [
|
||||
LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)
|
||||
]
|
||||
keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
|
||||
keys += [
|
||||
LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)
|
||||
]
|
||||
keys += [
|
||||
LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)
|
||||
]
|
||||
keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("time_in.out_layer", "time_text_embed.timestep_embedder.linear_2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("img_in.proj", "x_embedder", parent=key_prefix)]
|
||||
|
||||
for k in map_prefix_range("double_blocks", "transformer_blocks", parent=key_prefix):
|
||||
keys += __map_double_transformer_block(k)
|
||||
|
||||
for k in map_prefix_range("single_blocks", "single_transformer_blocks", parent=key_prefix):
|
||||
keys += __map_single_transformer_block(k)
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def convert_flux_lora_key_sets() -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
|
||||
keys += __map_transformer(LoraConversionKeySet("transformer", "lora_transformer"))
|
||||
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
|
||||
keys += map_t5(LoraConversionKeySet("t5", "lora_te2"))
|
||||
|
||||
return keys
|
||||
217
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_lora_util.py
vendored
Normal file
217
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_lora_util.py
vendored
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class LoraConversionKeySet:
|
||||
def __init__(
|
||||
self,
|
||||
omi_prefix: str,
|
||||
diffusers_prefix: str,
|
||||
legacy_diffusers_prefix: str | None = None,
|
||||
parent: Self | None = None,
|
||||
swap_chunks: bool = False,
|
||||
filter_is_last: bool | None = None,
|
||||
next_omi_prefix: str | None = None,
|
||||
next_diffusers_prefix: str | None = None,
|
||||
):
|
||||
if parent is not None:
|
||||
self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
|
||||
self.diffusers_prefix = combine(parent.diffusers_prefix, diffusers_prefix)
|
||||
else:
|
||||
self.omi_prefix = omi_prefix
|
||||
self.diffusers_prefix = diffusers_prefix
|
||||
|
||||
if legacy_diffusers_prefix is None:
|
||||
self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
|
||||
elif parent is not None:
|
||||
self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
|
||||
".", "_"
|
||||
)
|
||||
else:
|
||||
self.legacy_diffusers_prefix = legacy_diffusers_prefix
|
||||
|
||||
self.parent = parent
|
||||
self.swap_chunks = swap_chunks
|
||||
self.filter_is_last = filter_is_last
|
||||
self.prefix = parent
|
||||
|
||||
if next_omi_prefix is None and parent is not None:
|
||||
self.next_omi_prefix = parent.next_omi_prefix
|
||||
self.next_diffusers_prefix = parent.next_diffusers_prefix
|
||||
self.next_legacy_diffusers_prefix = parent.next_legacy_diffusers_prefix
|
||||
elif next_omi_prefix is not None and parent is not None:
|
||||
self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
|
||||
self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
|
||||
self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
|
||||
".", "_"
|
||||
)
|
||||
elif next_omi_prefix is not None and parent is None:
|
||||
self.next_omi_prefix = next_omi_prefix
|
||||
self.next_diffusers_prefix = next_diffusers_prefix
|
||||
self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
|
||||
else:
|
||||
self.next_omi_prefix = None
|
||||
self.next_diffusers_prefix = None
|
||||
self.next_legacy_diffusers_prefix = None
|
||||
|
||||
def __get_omi(self, in_prefix: str, key: str) -> str:
|
||||
return self.omi_prefix + key.removeprefix(in_prefix)
|
||||
|
||||
def __get_diffusers(self, in_prefix: str, key: str) -> str:
|
||||
return self.diffusers_prefix + key.removeprefix(in_prefix)
|
||||
|
||||
def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
|
||||
key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)
|
||||
|
||||
suffix = key[key.rfind(".") :]
|
||||
if suffix not in [".alpha", ".dora_scale"]: # some keys only have a single . in the suffix
|
||||
suffix = key[key.removesuffix(suffix).rfind(".") :]
|
||||
key = key.removesuffix(suffix)
|
||||
|
||||
return key.replace(".", "_") + suffix
|
||||
|
||||
def get_key(self, in_prefix: str, key: str, target: str) -> str:
|
||||
if target == "omi":
|
||||
return self.__get_omi(in_prefix, key)
|
||||
elif target == "diffusers":
|
||||
return self.__get_diffusers(in_prefix, key)
|
||||
elif target == "legacy_diffusers":
|
||||
return self.__get_legacy_diffusers(in_prefix, key)
|
||||
return key
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"omi: {self.omi_prefix}, diffusers: {self.diffusers_prefix}, legacy: {self.legacy_diffusers_prefix}"
|
||||
|
||||
|
||||
def combine(left: str, right: str) -> str:
|
||||
left = left.rstrip(".")
|
||||
right = right.lstrip(".")
|
||||
if left == "" or left is None:
|
||||
return right
|
||||
elif right == "" or right is None:
|
||||
return left
|
||||
else:
|
||||
return left + "." + right
|
||||
|
||||
|
||||
def map_prefix_range(
|
||||
omi_prefix: str,
|
||||
diffusers_prefix: str,
|
||||
parent: LoraConversionKeySet,
|
||||
) -> list[LoraConversionKeySet]:
|
||||
# 100 should be a safe upper bound. increase if it's not enough in the future
|
||||
return [
|
||||
LoraConversionKeySet(
|
||||
omi_prefix=f"{omi_prefix}.{i}",
|
||||
diffusers_prefix=f"{diffusers_prefix}.{i}",
|
||||
parent=parent,
|
||||
next_omi_prefix=f"{omi_prefix}.{i + 1}",
|
||||
next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
|
||||
)
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
|
||||
def __convert(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
source: str,
|
||||
target: str,
|
||||
) -> dict[str, Tensor]:
|
||||
out_states = {}
|
||||
|
||||
if source == target:
|
||||
return dict(state_dict)
|
||||
|
||||
# TODO: maybe replace with a non O(n^2) algorithm
|
||||
for key, tensor in state_dict.items():
|
||||
for key_set in key_sets:
|
||||
in_prefix = ""
|
||||
|
||||
if source == "omi":
|
||||
in_prefix = key_set.omi_prefix
|
||||
elif source == "diffusers":
|
||||
in_prefix = key_set.diffusers_prefix
|
||||
elif source == "legacy_diffusers":
|
||||
in_prefix = key_set.legacy_diffusers_prefix
|
||||
|
||||
if not key.startswith(in_prefix):
|
||||
continue
|
||||
|
||||
if key_set.filter_is_last is not None:
|
||||
next_prefix = None
|
||||
if source == "omi":
|
||||
next_prefix = key_set.next_omi_prefix
|
||||
elif source == "diffusers":
|
||||
next_prefix = key_set.next_diffusers_prefix
|
||||
elif source == "legacy_diffusers":
|
||||
next_prefix = key_set.next_legacy_diffusers_prefix
|
||||
|
||||
is_last = not any(k.startswith(next_prefix) for k in state_dict)
|
||||
if key_set.filter_is_last != is_last:
|
||||
continue
|
||||
|
||||
name = key_set.get_key(in_prefix, key, target)
|
||||
|
||||
can_swap_chunks = target == "omi" or source == "omi"
|
||||
if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
|
||||
chunk_0, chunk_1 = tensor.chunk(2, dim=0)
|
||||
tensor = torch.cat([chunk_1, chunk_0], dim=0)
|
||||
|
||||
out_states[name] = tensor
|
||||
|
||||
break # only map the first matching key set
|
||||
|
||||
return out_states
|
||||
|
||||
|
||||
def __detect_source(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> str:
|
||||
omi_count = 0
|
||||
diffusers_count = 0
|
||||
legacy_diffusers_count = 0
|
||||
|
||||
for key in state_dict:
|
||||
for key_set in key_sets:
|
||||
if key.startswith(key_set.omi_prefix):
|
||||
omi_count += 1
|
||||
if key.startswith(key_set.diffusers_prefix):
|
||||
diffusers_count += 1
|
||||
if key.startswith(key_set.legacy_diffusers_prefix):
|
||||
legacy_diffusers_count += 1
|
||||
|
||||
if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
|
||||
return "omi"
|
||||
if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
|
||||
return "diffusers"
|
||||
if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
|
||||
return "legacy_diffusers"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def convert_to_omi(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> dict[str, Tensor]:
|
||||
source = __detect_source(state_dict, key_sets)
|
||||
return __convert(state_dict, key_sets, source, "omi")
|
||||
|
||||
|
||||
def convert_to_diffusers(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> dict[str, Tensor]:
|
||||
source = __detect_source(state_dict, key_sets)
|
||||
return __convert(state_dict, key_sets, source, "diffusers")
|
||||
|
||||
|
||||
def convert_to_legacy_diffusers(
|
||||
state_dict: dict[str, Tensor],
|
||||
key_sets: list[LoraConversionKeySet],
|
||||
) -> dict[str, Tensor]:
|
||||
source = __detect_source(state_dict, key_sets)
|
||||
return __convert(state_dict, key_sets, source, "legacy_diffusers")
|
||||
125
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_sdxl_lora.py
vendored
Normal file
125
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_sdxl_lora.py
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
|
||||
|
||||
def __map_unet_resnet_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("emb_layers.1", "time_emb_proj", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("in_layers.2", "conv1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("out_layers.3", "conv2", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("skip_connection", "conv_shortcut", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_attention_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("proj_in", "proj_in", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("proj_out", "proj_out", parent=key_prefix)]
|
||||
for k in map_prefix_range("transformer_blocks", "transformer_blocks", parent=key_prefix):
|
||||
keys += [LoraConversionKeySet("attn1.to_q", "attn1.to_q", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn1.to_k", "attn1.to_k", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn1.to_v", "attn1.to_v", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn1.to_out.0", "attn1.to_out.0", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_q", "attn2.to_q", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_k", "attn2.to_k", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_v", "attn2.to_v", parent=k)]
|
||||
keys += [LoraConversionKeySet("attn2.to_out.0", "attn2.to_out.0", parent=k)]
|
||||
keys += [LoraConversionKeySet("ff.net.0.proj", "ff.net.0.proj", parent=k)]
|
||||
keys += [LoraConversionKeySet("ff.net.2", "ff.net.2", parent=k)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_down_blocks(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.1", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("3.0.op", "0.downsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.1", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("6.0.op", "1.downsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("7.1", "2.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("8.1", "2.attentions.1", parent=key_prefix))
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_mid_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("0", "resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("1", "attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("2", "resnets.1", parent=key_prefix))
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet_up_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("0.0", "0.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("0.1", "0.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("1.1", "0.attentions.1", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.2", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("2.1", "0.attentions.2", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("2.2.conv", "0.upsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("3.0", "1.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("3.1", "1.attentions.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.1", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.2", parent=key_prefix))
|
||||
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.2", parent=key_prefix))
|
||||
keys += [LoraConversionKeySet("5.2.conv", "1.upsamplers.0.conv", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("6.0", "2.resnets.0", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.1", parent=key_prefix))
|
||||
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.2", parent=key_prefix))
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def __map_unet(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("input_blocks.0.0", "conv_in", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("time_embed.0", "time_embedding.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("time_embed.2", "time_embedding.linear_2", parent=key_prefix)]
|
||||
|
||||
keys += [LoraConversionKeySet("label_emb.0.0", "add_embedding.linear_1", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("label_emb.0.2", "add_embedding.linear_2", parent=key_prefix)]
|
||||
|
||||
keys += __map_unet_down_blocks(LoraConversionKeySet("input_blocks", "down_blocks", parent=key_prefix))
|
||||
keys += __map_unet_mid_block(LoraConversionKeySet("middle_block", "mid_block", parent=key_prefix))
|
||||
keys += __map_unet_up_block(LoraConversionKeySet("output_blocks", "up_blocks", parent=key_prefix))
|
||||
|
||||
keys += [LoraConversionKeySet("out.0", "conv_norm_out", parent=key_prefix)]
|
||||
keys += [LoraConversionKeySet("out.2", "conv_out", parent=key_prefix)]
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
def convert_sdxl_lora_key_sets() -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
|
||||
keys += __map_unet(LoraConversionKeySet("unet", "lora_unet"))
|
||||
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
|
||||
keys += map_clip(LoraConversionKeySet("clip_g", "lora_te2"))
|
||||
|
||||
return keys
|
||||
19
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_t5.py
vendored
Normal file
19
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_t5.py
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
|
||||
LoraConversionKeySet,
|
||||
map_prefix_range,
|
||||
)
|
||||
|
||||
|
||||
def map_t5(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
|
||||
keys = []
|
||||
|
||||
for k in map_prefix_range("encoder.block", "encoder.block", parent=key_prefix):
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.k", "layer.0.SelfAttention.k", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.o", "layer.0.SelfAttention.o", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.q", "layer.0.SelfAttention.q", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.0.SelfAttention.v", "layer.0.SelfAttention.v", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_0", "layer.1.DenseReluDense.wi_0", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_1", "layer.1.DenseReluDense.wi_1", parent=k)]
|
||||
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wo", "layer.1.DenseReluDense.wo", parent=k)]
|
||||
|
||||
return keys
|
||||
0
invokeai/backend/model_manager/omi/vendor/model_spec/__init__.py
vendored
Normal file
0
invokeai/backend/model_manager/omi/vendor/model_spec/__init__.py
vendored
Normal file
31
invokeai/backend/model_manager/omi/vendor/model_spec/architecture.py
vendored
Normal file
31
invokeai/backend/model_manager/omi/vendor/model_spec/architecture.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
stable_diffusion_1_lora = "stable-diffusion-v1/lora"
|
||||
stable_diffusion_1_inpainting_lora = "stable-diffusion-v1-inpainting/lora"
|
||||
|
||||
stable_diffusion_2_512_lora = "stable-diffusion-v2-512/lora"
|
||||
stable_diffusion_2_768_v_lora = "stable-diffusion-v2-768-v/lora"
|
||||
stable_diffusion_2_depth_lora = "stable-diffusion-v2-depth/lora"
|
||||
stable_diffusion_2_inpainting_lora = "stable-diffusion-v2-inpainting/lora"
|
||||
|
||||
stable_diffusion_3_medium_lora = "stable-diffusion-v3-medium/lora"
|
||||
stable_diffusion_35_medium_lora = "stable-diffusion-v3.5-medium/lora"
|
||||
stable_diffusion_35_large_lora = "stable-diffusion-v3.5-large/lora"
|
||||
|
||||
stable_diffusion_xl_1_lora = "stable-diffusion-xl-v1-base/lora"
|
||||
stable_diffusion_xl_1_inpainting_lora = "stable-diffusion-xl-v1-base-inpainting/lora"
|
||||
|
||||
wuerstchen_2_lora = "wuerstchen-v2-prior/lora"
|
||||
stable_cascade_1_stage_a_lora = "stable-cascade-v1-stage-a/lora"
|
||||
stable_cascade_1_stage_b_lora = "stable-cascade-v1-stage-b/lora"
|
||||
stable_cascade_1_stage_c_lora = "stable-cascade-v1-stage-c/lora"
|
||||
|
||||
pixart_alpha_lora = "pixart-alpha/lora"
|
||||
pixart_sigma_lora = "pixart-sigma/lora"
|
||||
|
||||
flux_dev_1_lora = "Flux.1-dev/lora"
|
||||
flux_fill_dev_1_lora = "Flux.1-fill-dev/lora"
|
||||
|
||||
sana_lora = "sana/lora"
|
||||
|
||||
hunyuan_video_lora = "hunyuan-video/lora"
|
||||
|
||||
hi_dream_i1_lora = "hidream-i1/lora"
|
||||
@@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies):
|
||||
dependencies: Optional[list[StarterModelWithoutDependencies]] = None
|
||||
|
||||
|
||||
class StarterModelBundles(BaseModel):
|
||||
class StarterModelBundle(BaseModel):
|
||||
name: str
|
||||
models: list[StarterModel]
|
||||
|
||||
@@ -109,7 +109,7 @@ flux_vae = StarterModel(
|
||||
|
||||
# region: Main
|
||||
flux_schnell_quantized = StarterModel(
|
||||
name="FLUX Schnell (Quantized)",
|
||||
name="FLUX.1 schnell (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
|
||||
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
@@ -117,7 +117,7 @@ flux_schnell_quantized = StarterModel(
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_dev_quantized = StarterModel(
|
||||
name="FLUX Dev (Quantized)",
|
||||
name="FLUX.1 dev (quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
|
||||
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
@@ -125,7 +125,7 @@ flux_dev_quantized = StarterModel(
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_schnell = StarterModel(
|
||||
name="FLUX Schnell",
|
||||
name="FLUX.1 schnell",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
|
||||
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
@@ -133,13 +133,29 @@ flux_schnell = StarterModel(
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_dev = StarterModel(
|
||||
name="FLUX Dev",
|
||||
name="FLUX.1 dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
|
||||
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_kontext = StarterModel(
|
||||
name="FLUX.1 Kontext dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
|
||||
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_kontext_quantized = StarterModel(
|
||||
name="FLUX.1 Kontext dev (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
@@ -297,6 +313,15 @@ ip_adapter_sdxl = StarterModel(
|
||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||
previous_names=["IP Adapter SDXL"],
|
||||
)
|
||||
ip_adapter_plus_sdxl = StarterModel(
|
||||
name="Precise Reference (IP Adapter Plus ViT-H)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h/resolve/main/ip-adapter-plus_sdxl_vit-h.safetensors",
|
||||
description="References images with a higher degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||
previous_names=["IP Adapter Plus SDXL"],
|
||||
)
|
||||
ip_adapter_flux = StarterModel(
|
||||
name="Standard Reference (XLabs FLUX IP-Adapter v2)",
|
||||
base=BaseModelType.Flux,
|
||||
@@ -647,6 +672,7 @@ flux_fill = StarterModel(
|
||||
# List of starter models, displayed on the frontend.
|
||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||
STARTER_MODELS: list[StarterModel] = [
|
||||
flux_kontext_quantized,
|
||||
flux_schnell_quantized,
|
||||
flux_dev_quantized,
|
||||
flux_schnell,
|
||||
@@ -672,6 +698,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
ip_adapter_plus_sd1,
|
||||
ip_adapter_plus_face_sd1,
|
||||
ip_adapter_sdxl,
|
||||
ip_adapter_plus_sdxl,
|
||||
ip_adapter_flux,
|
||||
qr_code_cnet_sd1,
|
||||
qr_code_cnet_sdxl,
|
||||
@@ -744,6 +771,7 @@ sdxl_bundle: list[StarterModel] = [
|
||||
juggernaut_sdxl,
|
||||
sdxl_fp16_vae_fix,
|
||||
ip_adapter_sdxl,
|
||||
ip_adapter_plus_sdxl,
|
||||
canny_sdxl,
|
||||
depth_sdxl,
|
||||
softedge_sdxl,
|
||||
@@ -765,12 +793,13 @@ flux_bundle: list[StarterModel] = [
|
||||
flux_depth_control_lora,
|
||||
flux_redux,
|
||||
flux_fill,
|
||||
flux_kontext_quantized,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
|
||||
BaseModelType.StableDiffusion1: sd1_bundle,
|
||||
BaseModelType.StableDiffusionXL: sdxl_bundle,
|
||||
BaseModelType.Flux: flux_bundle,
|
||||
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
|
||||
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
|
||||
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
|
||||
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
|
||||
}
|
||||
|
||||
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
|
||||
|
||||
@@ -26,7 +26,10 @@ class BaseModelType(str, Enum):
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
CogView4 = "cogview4"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
Imagen3 = "imagen3"
|
||||
Imagen4 = "imagen4"
|
||||
ChatGPT4o = "chatgpt-4o"
|
||||
FluxKontext = "flux-kontext"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
@@ -86,6 +89,7 @@ class ModelVariantType(str, Enum):
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
OMI = "omi"
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
LyCORIS = "lycoris"
|
||||
@@ -98,6 +102,7 @@ class ModelFormat(str, Enum):
|
||||
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
GGUFQuantized = "gguf_quantized"
|
||||
Api = "api"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
@@ -134,6 +139,7 @@ class FluxLoRAFormat(str, Enum):
|
||||
Kohya = "flux.kohya"
|
||||
OneTrainer = "flux.onetrainer"
|
||||
Control = "flux.control"
|
||||
AIToolkit = "flux.aitoolkit"
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
|
||||
@@ -46,6 +46,10 @@ class ModelPatcher:
|
||||
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
|
||||
ti_list: List[Tuple[str, TextualInversionModelRaw]],
|
||||
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
||||
if len(ti_list) == 0:
|
||||
yield tokenizer, TextualInversionManager(tokenizer)
|
||||
return
|
||||
|
||||
init_tokens_count = None
|
||||
new_tokens_added = None
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user