mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 06:38:08 -05:00
Compare commits
609 Commits
v4.2.9.dev
...
v4.2.9.dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c5abd44a7 | ||
|
|
765d99ac2f | ||
|
|
ac9a66a628 | ||
|
|
0ea88dc170 | ||
|
|
8369826d22 | ||
|
|
0e354f5164 | ||
|
|
41f2ee2633 | ||
|
|
4e74006c5f | ||
|
|
48edb6e023 | ||
|
|
aeae6af0a1 | ||
|
|
ab11d9af8e | ||
|
|
2e84327ca4 | ||
|
|
fa6842121c | ||
|
|
c402aa397d | ||
|
|
a58c8adc38 | ||
|
|
d43e2d690e | ||
|
|
284f768810 | ||
|
|
e933d1ae2b | ||
|
|
1e134de771 | ||
|
|
29c47c8be5 | ||
|
|
e1122c541d | ||
|
|
2f81d1ac83 | ||
|
|
56fbe751db | ||
|
|
93f1d67fbf | ||
|
|
9467b937ff | ||
|
|
4242e6e6c2 | ||
|
|
9b39452b3e | ||
|
|
85b23784cf | ||
|
|
085cc82926 | ||
|
|
0098c33f81 | ||
|
|
292e00ab68 | ||
|
|
6c1fb2d06e | ||
|
|
d60605fcd8 | ||
|
|
38ed720ff2 | ||
|
|
22203b8eb0 | ||
|
|
cf5fa792a1 | ||
|
|
c636633a8e | ||
|
|
55fe1ebc53 | ||
|
|
3c2fa6b475 | ||
|
|
9b927de2e0 | ||
|
|
6a62854e7d | ||
|
|
312093cbb0 | ||
|
|
06fe14e1fc | ||
|
|
1b54e58726 | ||
|
|
219d7c9611 | ||
|
|
9f742a669e | ||
|
|
41e324fd51 | ||
|
|
ce55a96125 | ||
|
|
64e60a7fde | ||
|
|
972f03960a | ||
|
|
5a403f087d | ||
|
|
fe59d7f3b0 | ||
|
|
b2b2b73aed | ||
|
|
20b563c4cb | ||
|
|
263a0ef5b4 | ||
|
|
e8723b7cd3 | ||
|
|
03e05b2068 | ||
|
|
6c0482a71d | ||
|
|
e6153e6fa4 | ||
|
|
6d209c6cc3 | ||
|
|
beb4e823dc | ||
|
|
61ba4c606b | ||
|
|
af840cedf3 | ||
|
|
0bf0bca03f | ||
|
|
e470eaf8f3 | ||
|
|
377db3f726 | ||
|
|
77f020a997 | ||
|
|
34e2eda625 | ||
|
|
e1d559db69 | ||
|
|
23a98e2ed6 | ||
|
|
fe3b2ed357 | ||
|
|
eedf81dcc5 | ||
|
|
dbef1a9e06 | ||
|
|
a41406ca9a | ||
|
|
f126a61f66 | ||
|
|
89c79276f3 | ||
|
|
423e463b95 | ||
|
|
52202e45de | ||
|
|
100832c66d | ||
|
|
a58b91b221 | ||
|
|
3af6d79852 | ||
|
|
1303e18e93 | ||
|
|
301da97670 | ||
|
|
17e76981bb | ||
|
|
9c1732e2bb | ||
|
|
a3179e7a3f | ||
|
|
f86b50d18a | ||
|
|
307885f505 | ||
|
|
4b49c1dd6b | ||
|
|
f917cefa84 | ||
|
|
bea98438fc | ||
|
|
17d3275086 | ||
|
|
059b7a0fcf | ||
|
|
05d3a989f6 | ||
|
|
590ae70c12 | ||
|
|
5240ec6e6f | ||
|
|
04772b642c | ||
|
|
65f6cb416f | ||
|
|
24c2028739 | ||
|
|
b0db9a3f56 | ||
|
|
3ea83574c0 | ||
|
|
05252a9bfc | ||
|
|
ce854f086e | ||
|
|
ff0c16978c | ||
|
|
41cc650031 | ||
|
|
c3f7554053 | ||
|
|
3f597a1c60 | ||
|
|
ccffdf1878 | ||
|
|
474089e892 | ||
|
|
778e8ad161 | ||
|
|
9f29892c24 | ||
|
|
56fd46a069 | ||
|
|
579e594861 | ||
|
|
af3440fbe3 | ||
|
|
cc101f55c4 | ||
|
|
ef1adf07f5 | ||
|
|
625c05d9be | ||
|
|
8ad3d8f738 | ||
|
|
4759875733 | ||
|
|
768e6a3c55 | ||
|
|
45bd85c039 | ||
|
|
9f94c5a8bd | ||
|
|
23fdd65961 | ||
|
|
8034195c30 | ||
|
|
08761127c9 | ||
|
|
4a10010b6c | ||
|
|
14cc5e2453 | ||
|
|
3d87adea60 | ||
|
|
36e8232ab6 | ||
|
|
72722a73be | ||
|
|
a09aa232a9 | ||
|
|
7ae8b64699 | ||
|
|
60e0d17f34 | ||
|
|
bf8bef2f00 | ||
|
|
b586d67bac | ||
|
|
31e5e5af13 | ||
|
|
94871e88cd | ||
|
|
00e56d1968 | ||
|
|
43672a53ab | ||
|
|
45097ed2a6 | ||
|
|
871f6b9f95 | ||
|
|
e6476e3c75 | ||
|
|
ac9b5f246d | ||
|
|
8bc72a2744 | ||
|
|
f76f1d89d7 | ||
|
|
7b54762b5e | ||
|
|
bc6faf6a6d | ||
|
|
e7ae1ac9b2 | ||
|
|
dcb436adb1 | ||
|
|
80f0441905 | ||
|
|
8cde803654 | ||
|
|
62445680ad | ||
|
|
7685e36886 | ||
|
|
4c196844bd | ||
|
|
b36159bda4 | ||
|
|
b02948d49a | ||
|
|
f442d206be | ||
|
|
21ed6bccd8 | ||
|
|
143ce7f00b | ||
|
|
28e716139b | ||
|
|
80a7c0c521 | ||
|
|
255ad3d2ad | ||
|
|
089bc9c7d8 | ||
|
|
ee7dafaf57 | ||
|
|
516ecdb0ee | ||
|
|
b77675f74d | ||
|
|
eea5c8efad | ||
|
|
09f1aac3a3 | ||
|
|
dd1dcb5eba | ||
|
|
757bd62ebe | ||
|
|
5a3127949b | ||
|
|
ced934c0a3 | ||
|
|
c32445084f | ||
|
|
9f1af0cdaa | ||
|
|
0d26cab400 | ||
|
|
c8de2da3fc | ||
|
|
ca089a105e | ||
|
|
22000918d6 | ||
|
|
6affc28da4 | ||
|
|
f659995e1c | ||
|
|
56fb3e738f | ||
|
|
56d450a907 | ||
|
|
d3cdcef36b | ||
|
|
19434e73b4 | ||
|
|
f7b3df9583 | ||
|
|
4da4b3bd50 | ||
|
|
e83513882a | ||
|
|
5adc784b6b | ||
|
|
f177513523 | ||
|
|
8ebcf79b1a | ||
|
|
c7e5f24704 | ||
|
|
ab3eb32ec8 | ||
|
|
d76509e5cb | ||
|
|
04f56aab82 | ||
|
|
c7913cbbbb | ||
|
|
0556468518 | ||
|
|
1c7ef827b6 | ||
|
|
5720ed4d64 | ||
|
|
7f05af4a68 | ||
|
|
6db615ed5a | ||
|
|
465f020c86 | ||
|
|
f05b77088f | ||
|
|
80a5abf1ad | ||
|
|
7a6e8de60f | ||
|
|
8364fa74cf | ||
|
|
14f4566dd0 | ||
|
|
6145378923 | ||
|
|
68e2606427 | ||
|
|
0f3eb04d1a | ||
|
|
4a355323b2 | ||
|
|
8601fbb4ea | ||
|
|
db885aa180 | ||
|
|
c18fb980a2 | ||
|
|
b630dbdf20 | ||
|
|
29ac1b5e01 | ||
|
|
506d3b079e | ||
|
|
0670e6b53a | ||
|
|
76124ea35b | ||
|
|
6eae3470cd | ||
|
|
c7ba7ac876 | ||
|
|
edc733abd9 | ||
|
|
a56ded664e | ||
|
|
31ace5fb0c | ||
|
|
11010236b3 | ||
|
|
5f061ac1e2 | ||
|
|
72919fa34e | ||
|
|
d5ca99fc3c | ||
|
|
e49b72ee4e | ||
|
|
abe8db8154 | ||
|
|
e0e5941384 | ||
|
|
86e1f4e8b0 | ||
|
|
447d873ef0 | ||
|
|
b21d613ce4 | ||
|
|
fc91adb32f | ||
|
|
71885db5fd | ||
|
|
b88d14b3df | ||
|
|
d98d35a8a8 | ||
|
|
87bc0ebd73 | ||
|
|
7b6ba3f690 | ||
|
|
b0d8948428 | ||
|
|
b32d681cee | ||
|
|
11a66d1d09 | ||
|
|
e41987f08c | ||
|
|
34b57ec188 | ||
|
|
d74843be31 | ||
|
|
1216c6f9c9 | ||
|
|
865b6017d3 | ||
|
|
922a021821 | ||
|
|
0b5f4cac57 | ||
|
|
c988c58c63 | ||
|
|
ceb8cbf59e | ||
|
|
52e9f43c46 | ||
|
|
4e5e7761fc | ||
|
|
9879999a65 | ||
|
|
bedaca70a3 | ||
|
|
2dd2225d2e | ||
|
|
d82031eec1 | ||
|
|
e5f2860b74 | ||
|
|
fa3560bb61 | ||
|
|
9b23f6ce30 | ||
|
|
5d6aa6cfd5 | ||
|
|
7d1819335f | ||
|
|
539e7a3f2d | ||
|
|
1686924ac8 | ||
|
|
556c1dc67b | ||
|
|
00f7093e65 | ||
|
|
79eb11dce9 | ||
|
|
0bf48c0d41 | ||
|
|
3f33e5f770 | ||
|
|
da3888ba9e | ||
|
|
a2f91b1055 | ||
|
|
d26095dfa1 | ||
|
|
83e786bd1e | ||
|
|
4cae12a507 | ||
|
|
d8e3708e0f | ||
|
|
f4de2fd3b1 | ||
|
|
e1cb30bbb4 | ||
|
|
97e0edc549 | ||
|
|
f4e66bf14f | ||
|
|
a6a7fe8aba | ||
|
|
a273f72560 | ||
|
|
b5126f45d6 | ||
|
|
ba3bb7cbf3 | ||
|
|
608279487b | ||
|
|
72b5374916 | ||
|
|
08b03212ca | ||
|
|
7e341a05a1 | ||
|
|
e665d08ee1 | ||
|
|
ba6362dc9d | ||
|
|
48f0797c43 | ||
|
|
640b0c4939 | ||
|
|
287c61e277 | ||
|
|
f7b2516109 | ||
|
|
b530eb49d4 | ||
|
|
fa94979ab6 | ||
|
|
54f2acf5b9 | ||
|
|
b6d845a4d0 | ||
|
|
1095b7c37f | ||
|
|
136ffd97ca | ||
|
|
80163d0af2 | ||
|
|
e1c6e926e7 | ||
|
|
2bb74abf31 | ||
|
|
0d4b91afe0 | ||
|
|
6c688d6878 | ||
|
|
243feecef9 | ||
|
|
abd22ba087 | ||
|
|
ab25546e97 | ||
|
|
925f0fca2a | ||
|
|
066366d885 | ||
|
|
61d52e96b7 | ||
|
|
051e88ca90 | ||
|
|
e873b69850 | ||
|
|
661fd55556 | ||
|
|
402f5a4717 | ||
|
|
81bf52ef37 | ||
|
|
8ff92796df | ||
|
|
68af60e12e | ||
|
|
cce6bf9428 | ||
|
|
078908fbea | ||
|
|
7275caaf5b | ||
|
|
d9487c1df4 | ||
|
|
3a9f955388 | ||
|
|
e46c7acd2e | ||
|
|
b771664851 | ||
|
|
7c21819d20 | ||
|
|
a57e618d47 | ||
|
|
c9849a79ea | ||
|
|
f1643fec08 | ||
|
|
951e63ca87 | ||
|
|
8e539c8a8c | ||
|
|
1e689a4902 | ||
|
|
7bbd25b5ec | ||
|
|
b1c7236117 | ||
|
|
ae3e473024 | ||
|
|
fd616f247c | ||
|
|
45dca2c821 | ||
|
|
40dc108c84 | ||
|
|
a421c25952 | ||
|
|
562d0afdbb | ||
|
|
2ce4698eef | ||
|
|
cb53108041 | ||
|
|
5fa65e5cc6 | ||
|
|
e8b0b6cef5 | ||
|
|
eca2712828 | ||
|
|
2804c0aede | ||
|
|
0429f0480d | ||
|
|
024759a0fc | ||
|
|
9a94aef2b0 | ||
|
|
e329cb45cd | ||
|
|
0dc38bd684 | ||
|
|
98ebca5f8c | ||
|
|
05cb3e03cf | ||
|
|
181132c149 | ||
|
|
a69aa00155 | ||
|
|
47d415e31c | ||
|
|
667a156817 | ||
|
|
00f39b977e | ||
|
|
e5776e2bd6 | ||
|
|
2b21f54897 | ||
|
|
678d12fcd5 | ||
|
|
03f06f611e | ||
|
|
6571e0f814 | ||
|
|
44f91026e1 | ||
|
|
56237328f1 | ||
|
|
ff68901e89 | ||
|
|
e0e7adb2b2 | ||
|
|
0923a5b128 | ||
|
|
75f8a84c79 | ||
|
|
af815cf7eb | ||
|
|
ef4d6c26f6 | ||
|
|
5087b306c0 | ||
|
|
a5708eaefe | ||
|
|
389bfc9e31 | ||
|
|
fd269e91e0 | ||
|
|
80136b0dfc | ||
|
|
9595eff1f9 | ||
|
|
c3c95754f7 | ||
|
|
22ab63fe8d | ||
|
|
5fefcab475 | ||
|
|
771a05b894 | ||
|
|
e2d8aaa923 | ||
|
|
0951aecb13 | ||
|
|
b1fe6f9853 | ||
|
|
551dd393aa | ||
|
|
78b4562184 | ||
|
|
c49b90e621 | ||
|
|
89e6233fbf | ||
|
|
3f9496c237 | ||
|
|
36e94af598 | ||
|
|
a181a684f5 | ||
|
|
bb712b3b3f | ||
|
|
e795de5647 | ||
|
|
bdc428cdd8 | ||
|
|
e4376e21dd | ||
|
|
77acc7baed | ||
|
|
9db1556c4d | ||
|
|
65de8b329b | ||
|
|
08dae5b047 | ||
|
|
8d2f056407 | ||
|
|
e66ef2e25e | ||
|
|
4d3ee7e082 | ||
|
|
fe48fda2f3 | ||
|
|
0f66753aa1 | ||
|
|
a18878474b | ||
|
|
0aa4568fd4 | ||
|
|
1de7e5760a | ||
|
|
135d6f2763 | ||
|
|
061767ede3 | ||
|
|
7204844bcb | ||
|
|
f2279ecadd | ||
|
|
75694869d2 | ||
|
|
d029680ac1 | ||
|
|
41c195d936 | ||
|
|
03ea005e9c | ||
|
|
6d936a7c44 | ||
|
|
fba17b93a6 | ||
|
|
73a7a27ea1 | ||
|
|
79287c2d16 | ||
|
|
662c5f4b77 | ||
|
|
7728ca6843 | ||
|
|
9607372f89 | ||
|
|
d27f948b78 | ||
|
|
b7aab81717 | ||
|
|
2998287f61 | ||
|
|
55d7f0ff5b | ||
|
|
4564f36d4a | ||
|
|
319de5c4e9 | ||
|
|
eee499faa3 | ||
|
|
63c5e42f2a | ||
|
|
bd16dc4479 | ||
|
|
49371ddec9 | ||
|
|
6a10d31b19 | ||
|
|
c951e733d3 | ||
|
|
7ed24cf847 | ||
|
|
821b7a0435 | ||
|
|
1b0344c412 | ||
|
|
03ca3c4b3d | ||
|
|
b939192b16 | ||
|
|
7ccf559a06 | ||
|
|
9eb091f873 | ||
|
|
3bd5521641 | ||
|
|
ced748e419 | ||
|
|
fbd137da9f | ||
|
|
03baebced6 | ||
|
|
cb19c1c370 | ||
|
|
788bad61d0 | ||
|
|
8f5f9bd44e | ||
|
|
2873e3e084 | ||
|
|
b004f17ae3 | ||
|
|
bea1e8c99b | ||
|
|
111493223f | ||
|
|
0a5ac2baec | ||
|
|
eec3c3b884 | ||
|
|
07b72c3d70 | ||
|
|
766e8c4eb0 | ||
|
|
57c257d10d | ||
|
|
d497da0e61 | ||
|
|
62310e7929 | ||
|
|
d79aa173a6 | ||
|
|
fbfdd3e003 | ||
|
|
a62b4a26ef | ||
|
|
817d4168c6 | ||
|
|
7e0a6d1538 | ||
|
|
ebc498ad19 | ||
|
|
b97b8c6ce6 | ||
|
|
b8abff65a1 | ||
|
|
a953dc1dbd | ||
|
|
a7c9848e99 | ||
|
|
73a1449eaf | ||
|
|
59f57ff542 | ||
|
|
e9204b87e3 | ||
|
|
7dd11bd60a | ||
|
|
275fc2ccf9 | ||
|
|
a2ef8d9d47 | ||
|
|
196779ff19 | ||
|
|
aee3147365 | ||
|
|
eaca940956 | ||
|
|
06006733e2 | ||
|
|
14d0bfbef6 | ||
|
|
0c9cf73702 | ||
|
|
3b864921ac | ||
|
|
f41539532f | ||
|
|
657009c254 | ||
|
|
c47e02c309 | ||
|
|
ce8a7bc178 | ||
|
|
488ca87787 | ||
|
|
d965df8ca9 | ||
|
|
995c26751e | ||
|
|
dd09723a2a | ||
|
|
5ff5af3ba2 | ||
|
|
4cb85404c0 | ||
|
|
50bc2f100d | ||
|
|
f65ce6a019 | ||
|
|
c28b635f2d | ||
|
|
e55896240d | ||
|
|
2b478ee7e1 | ||
|
|
69912a35ea | ||
|
|
9f1bd98c7e | ||
|
|
b531d6b7f0 | ||
|
|
8aa963fb81 | ||
|
|
b76e0ab4e4 | ||
|
|
aea03b4e92 | ||
|
|
b39e95966c | ||
|
|
d53e5e0158 | ||
|
|
0368dd651b | ||
|
|
84a4a1024e | ||
|
|
af4f258489 | ||
|
|
ddfc8785b4 | ||
|
|
d8515b6efc | ||
|
|
6a07f007a4 | ||
|
|
7a5a0c8075 | ||
|
|
5ed2e9b0fc | ||
|
|
aeb0a45eb6 | ||
|
|
21e814d766 | ||
|
|
cafc1839e2 | ||
|
|
e937aa831f | ||
|
|
890e6a95ed | ||
|
|
a5b7274359 | ||
|
|
172acf2cf5 | ||
|
|
b49fdf6407 | ||
|
|
5184d05bc2 | ||
|
|
7ef4553fc9 | ||
|
|
d6bd1e4a49 | ||
|
|
29413f20a7 | ||
|
|
04a44c8ea7 | ||
|
|
426f1b6f9a | ||
|
|
9c7f5ed321 | ||
|
|
4c37c7f280 | ||
|
|
a2d13cacbf | ||
|
|
aa127b83a3 | ||
|
|
e55192ae2a | ||
|
|
5159fcbc33 | ||
|
|
02ad7a0f93 | ||
|
|
bfa496e37f | ||
|
|
fdf347af26 | ||
|
|
0833dbb19d | ||
|
|
1b6bf58e58 | ||
|
|
5ead7bc7b4 | ||
|
|
f326d17856 | ||
|
|
908aa9beea | ||
|
|
4071e96245 | ||
|
|
b4daf29bd8 | ||
|
|
bf185339c2 | ||
|
|
df3abc75c2 | ||
|
|
28fc9a387c | ||
|
|
8533f207dc | ||
|
|
d135c48319 | ||
|
|
ca9090d070 | ||
|
|
93b185dc3b | ||
|
|
98e5efa895 | ||
|
|
c6774b829d | ||
|
|
22925f92bd | ||
|
|
302efcf6e8 | ||
|
|
76f9f90f0a | ||
|
|
5ba338e471 | ||
|
|
01f101c6f2 | ||
|
|
5606aec78d | ||
|
|
db90e1fe8b | ||
|
|
ae96c479f2 | ||
|
|
344ed2c83e | ||
|
|
1985944659 | ||
|
|
915357a6c1 | ||
|
|
63c34e78d7 | ||
|
|
366c460c1f | ||
|
|
40cab08133 | ||
|
|
51de25122a | ||
|
|
90313091db | ||
|
|
9982219d18 | ||
|
|
b3fe03b8f9 | ||
|
|
6edd15d68a | ||
|
|
0e2b328c88 | ||
|
|
25d7f9c316 | ||
|
|
3870ebdf29 | ||
|
|
7595d05191 | ||
|
|
21af727d79 | ||
|
|
5691829de6 | ||
|
|
20e6a57cf1 | ||
|
|
d0c40a8b5b | ||
|
|
f663215f25 | ||
|
|
7c5dea6d12 | ||
|
|
87261bdbc9 | ||
|
|
4e4b6c6dbc | ||
|
|
5e8cf9fb6a | ||
|
|
c738fe051f | ||
|
|
29fe1533f2 | ||
|
|
77090070bd | ||
|
|
6ba9b1b6b0 | ||
|
|
c578b8df1e | ||
|
|
cad9a41433 | ||
|
|
5fefb3b0f4 | ||
|
|
5284a870b0 | ||
|
|
e064377c05 | ||
|
|
3e569c8312 | ||
|
|
16825ee6e9 | ||
|
|
3f5340fa53 | ||
|
|
f2a1a39b33 | ||
|
|
326de55d3e | ||
|
|
b2df909570 | ||
|
|
026ac36b06 | ||
|
|
92125e5fd2 | ||
|
|
c0c139da88 | ||
|
|
404ad6a7fd | ||
|
|
fc39086fb4 | ||
|
|
cd215700fe | ||
|
|
e97fd85904 | ||
|
|
0a263fa5b1 | ||
|
|
fae3836a8d | ||
|
|
b3d2eb4178 | ||
|
|
576f1cbb75 |
@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@@ -128,6 +130,7 @@ class FieldDescriptions:
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
vae = "VAE"
|
||||
|
||||
@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load CLIP.
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
# Load T5.
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info as clip_text_encoder,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
return pooled_prompt_embeds
|
||||
|
||||
@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||
latents = self._run_diffusion(context)
|
||||
image = self._run_vae_decoding(context, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_embeddings: torch.Tensor,
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Prepare input noise.
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
@@ -88,24 +90,19 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = prepare_latent_img_patches(x)
|
||||
x, img_ids = prepare_latent_img_patches(x)
|
||||
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
image_seq_len=x.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
@@ -140,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=img,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
|
||||
@@ -157,7 +157,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
@@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
model_key = self.model.key
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
if not context.models.exists(model_key):
|
||||
raise ValueError(f"Unknown model: {model_key}")
|
||||
transformer = self._get_model(context, SubModelType.Transformer)
|
||||
tokenizer = self._get_model(context, SubModelType.Tokenizer)
|
||||
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||
vae = self._get_model(context, SubModelType.VAE)
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
@@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
|
||||
match submodel:
|
||||
case SubModelType.Transformer:
|
||||
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
case SubModelType.VAE:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
SubModelType.VAE,
|
||||
"FLUX.1-schnell_ae",
|
||||
ModelType.VAE,
|
||||
BaseModelType.Flux,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
submodel,
|
||||
"clip-vit-large-patch14",
|
||||
ModelType.CLIPEmbed,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||
return self._pull_model_from_mm(
|
||||
context,
|
||||
submodel,
|
||||
self.t5_encoder.name,
|
||||
ModelType.T5Encoder,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||
|
||||
def _pull_model_from_mm(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
submodel: SubModelType,
|
||||
name: str,
|
||||
type: ModelType,
|
||||
base: BaseModelType,
|
||||
):
|
||||
if models := context.models.search_by_attrs(name=name, base=base, type=type):
|
||||
if len(models) != 1:
|
||||
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||
else:
|
||||
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
|
||||
@@ -88,7 +88,8 @@ class QueueItemEventBase(QueueEventBase):
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the queue item")
|
||||
destination: str | None = Field(default=None, description="The destination of the queue item")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
@@ -114,6 +115,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -148,6 +150,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -186,6 +189,7 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -219,6 +223,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -257,6 +262,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
|
||||
@@ -77,7 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of this batch.")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
@@ -196,7 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
origin: str | None = Field(default=None, description="The origin of this queue item. ")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
@@ -297,6 +311,7 @@ class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
origin: str | None = Field(..., description="The origin of the batch")
|
||||
destination: str | None = Field(..., description="The destination of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
@@ -443,6 +458,7 @@ class SessionQueueValueToInsert(NamedTuple):
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
destination: str | None
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
@@ -464,6 +480,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
||||
priority, # priority
|
||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||
batch.origin, # origin
|
||||
batch.destination, # destination
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -128,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
@@ -579,7 +579,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -659,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -672,6 +673,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -681,6 +683,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
origin=origin,
|
||||
destination=destination,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
|
||||
@@ -10,9 +10,11 @@ class Migration15Callback:
|
||||
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
|
||||
|
||||
|
||||
def build_migration_15() -> Migration:
|
||||
@@ -21,6 +23,7 @@ def build_migration_15() -> Migration:
|
||||
|
||||
This migration does the following:
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
migration_15 = Migration(
|
||||
from_version=14,
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"version": "1.0.0",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
@@ -20,8 +20,8 @@
|
||||
"fieldName": "num_steps"
|
||||
},
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"fieldName": "t5_encoder"
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
@@ -30,12 +30,12 @@
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.3",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
@@ -44,31 +44,25 @@
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
|
||||
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
|
||||
"name": "FLUX Dev (Quantized)",
|
||||
"base": "flux",
|
||||
"type": "main"
|
||||
}
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
|
||||
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
|
||||
"name": "t5_bnb_int8_quantized_encoder",
|
||||
"base": "any",
|
||||
"type": "t5_encoder"
|
||||
}
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 337.09365228062825,
|
||||
"y": 40.63469521079861
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -207,45 +201,45 @@
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
||||
"type": "default",
|
||||
|
||||
@@ -111,16 +111,7 @@ def denoise(
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
dtype = model.txt_in.bias.dtype
|
||||
|
||||
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
|
||||
img = img.to(dtype=dtype)
|
||||
img_ids = img_ids.to(dtype=dtype)
|
||||
txt = txt.to(dtype=dtype)
|
||||
txt_ids = txt_ids.to(dtype=dtype)
|
||||
vec = vec.to(dtype=dtype)
|
||||
|
||||
# this is ignored for schnell
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor,
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
|
||||
@@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
pass
|
||||
|
||||
config.path = str(self._get_model_path(config))
|
||||
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
|
||||
@@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -1,22 +1,6 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
@@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Implementation of ModelCacheBase."""
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
@@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
@@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
@@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
@@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
@@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
@@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
@@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
@@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
|
||||
@@ -54,8 +54,10 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# Currently, we only support weight_format=0.
|
||||
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
assert weight_format == 0
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
@@ -89,6 +91,14 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
# Reset the state. The persisted fields are based on the initialization behaviour in
|
||||
# `bnb.nn.Linear8bitLt.__init__()`.
|
||||
new_state = bnb.MatmulLtState()
|
||||
new_state.threshold = self.state.threshold
|
||||
new_state.has_fp16_weights = False
|
||||
new_state.use_pool = self.state.use_pool
|
||||
self.state = new_state
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
|
||||
@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
|
||||
@@ -3,10 +3,9 @@ Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import GIG, Chdir, directory_size
|
||||
from invokeai.backend.util.util import Chdir, directory_size
|
||||
|
||||
__all__ = [
|
||||
"GIG",
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
|
||||
@@ -7,9 +7,6 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
|
||||
@@ -164,10 +164,10 @@
|
||||
"alpha": "Alpha",
|
||||
"selected": "Selected",
|
||||
"tab": "Tab",
|
||||
"viewing": "Viewing",
|
||||
"viewingDesc": "Review images in a large gallery view",
|
||||
"editing": "Editing",
|
||||
"editingDesc": "Edit on the Control Layers canvas",
|
||||
"view": "View",
|
||||
"viewDesc": "Review images in a large gallery view",
|
||||
"edit": "Edit",
|
||||
"editDesc": "Edit on the Canvas",
|
||||
"comparing": "Comparing",
|
||||
"comparingDesc": "Comparing two images",
|
||||
"enabled": "Enabled",
|
||||
@@ -328,9 +328,13 @@
|
||||
"completedIn": "Completed in",
|
||||
"batch": "Batch",
|
||||
"origin": "Origin",
|
||||
"originCanvas": "Canvas",
|
||||
"originWorkflows": "Workflows",
|
||||
"originOther": "Other",
|
||||
"destination": "Destination",
|
||||
"upscaling": "Upscaling",
|
||||
"canvas": "Canvas",
|
||||
"generation": "Generation",
|
||||
"workflows": "Workflows",
|
||||
"other": "Other",
|
||||
"gallery": "Gallery",
|
||||
"batchFieldValues": "Batch Field Values",
|
||||
"item": "Item",
|
||||
"session": "Session",
|
||||
@@ -702,6 +706,8 @@
|
||||
"availableModels": "Available Models",
|
||||
"baseModel": "Base Model",
|
||||
"cancel": "Cancel",
|
||||
"clipEmbed": "CLIP Embed",
|
||||
"clipVision": "CLIP Vision",
|
||||
"config": "Config",
|
||||
"convert": "Convert",
|
||||
"convertingModelBegin": "Converting Model. Please wait.",
|
||||
@@ -789,6 +795,7 @@
|
||||
"settings": "Settings",
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"syncModels": "Sync Models",
|
||||
@@ -797,6 +804,7 @@
|
||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"uploadImage": "Upload Image",
|
||||
"urlOrLocalPath": "URL or Local Path",
|
||||
@@ -1646,6 +1654,13 @@
|
||||
"storeNotInitialized": "Store is not initialized"
|
||||
},
|
||||
"controlLayers": {
|
||||
"saveCanvasToGallery": "Save Canvas To Gallery",
|
||||
"saveBboxToGallery": "Save Bbox To Gallery",
|
||||
"savedToGalleryOk": "Saved to Gallery",
|
||||
"savedToGalleryError": "Error saving to gallery",
|
||||
"mergeVisible": "Merge Visible",
|
||||
"mergeVisibleOk": "Merged visible layers",
|
||||
"mergeVisibleError": "Error merging visible layers",
|
||||
"clearHistory": "Clear History",
|
||||
"generateMode": "Generate",
|
||||
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
|
||||
@@ -1675,32 +1690,44 @@
|
||||
"deletePrompt": "Delete Prompt",
|
||||
"resetRegion": "Reset Region",
|
||||
"debugLayers": "Debug Layers",
|
||||
"showHUD": "Show HUD",
|
||||
"rectangle": "Rectangle",
|
||||
"maskFill": "Mask Fill",
|
||||
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
||||
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
||||
"addIPAdapter": "Add $t(common.ipAdapter)",
|
||||
"addRasterLayer": "Add $t(controlLayers.rasterLayer)",
|
||||
"addControlLayer": "Add $t(controlLayers.controlLayer)",
|
||||
"addInpaintMask": "Add $t(controlLayers.inpaintMask)",
|
||||
"addRegionalGuidance": "Add $t(controlLayers.regionalGuidance)",
|
||||
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
||||
"raster": "Raster",
|
||||
"rasterLayer_one": "Raster Layer",
|
||||
"controlLayer_one": "Control Layer",
|
||||
"inpaintMask_one": "Inpaint Mask",
|
||||
"regionalGuidance_one": "Regional Guidance",
|
||||
"ipAdapter_one": "IP Adapter",
|
||||
"rasterLayer_other": "Raster Layers",
|
||||
"controlLayer_other": "Control Layers",
|
||||
"inpaintMask_other": "Inpaint Masks",
|
||||
"regionalGuidance_other": "Regional Guidance",
|
||||
"ipAdapter_other": "IP Adapters",
|
||||
"rasterLayer": "Raster Layer",
|
||||
"controlLayer": "Control Layer",
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"sendToGallery": "Send To Gallery",
|
||||
"sendToGalleryDesc": "Generations will be sent to the gallery.",
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
"sendToCanvasDesc": "Generations will be staged onto the canvas.",
|
||||
"rasterLayer_withCount_one": "$t(controlLayers.rasterLayer)",
|
||||
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
|
||||
"inpaintMask_withCount_one": "$t(controlLayers.inpaintMask)",
|
||||
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
|
||||
"ipAdapter_withCount_one": "$t(controlLayers.ipAdapter)",
|
||||
"rasterLayer_withCount_other": "Raster Layers",
|
||||
"controlLayer_withCount_other": "Control Layers",
|
||||
"inpaintMask_withCount_other": "Inpaint Masks",
|
||||
"regionalGuidance_withCount_other": "Regional Guidance",
|
||||
"ipAdapter_withCount_other": "IP Adapters",
|
||||
"opacity": "Opacity",
|
||||
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
|
||||
"controlAdapters_withCount_hidden": "Control Adapters ({{count}} hidden)",
|
||||
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
|
||||
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
|
||||
"ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
|
||||
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
|
||||
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
|
||||
"controlAdapters_withCount_visible": "Control Adapters ({{count}})",
|
||||
"controlLayers_withCount_visible": "Control Layers ({{count}})",
|
||||
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
|
||||
"ipAdapters_withCount_visible": "IP Adapters ({{count}})",
|
||||
@@ -1737,6 +1764,7 @@
|
||||
"flipHorizontal": "Flip Horizontal",
|
||||
"flipVertical": "Flip Vertical",
|
||||
"fill": {
|
||||
"fillColor": "Fill Color",
|
||||
"fillStyle": "Fill Style",
|
||||
"solid": "Solid",
|
||||
"grid": "Grid",
|
||||
|
||||
@@ -16,6 +16,7 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
|
||||
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
@@ -43,10 +44,17 @@ interface Props {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
destination?: TabName | undefined;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: TabName;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(selectLanguage);
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -85,6 +93,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
|
||||
}
|
||||
}, [selectedWorkflowId, getAndLoadWorkflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedStylePresetId) {
|
||||
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
|
||||
}
|
||||
}, [dispatch, selectedStylePresetId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (destination) {
|
||||
dispatch(setActiveTab(destination));
|
||||
|
||||
@@ -45,6 +45,7 @@ interface Props extends PropsWithChildren {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: TabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
@@ -66,6 +67,7 @@ const InvokeAIUI = ({
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@@ -227,6 +229,7 @@ const InvokeAIUI = ({
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
selectedStylePresetId={selectedStylePresetId}
|
||||
destination={destination}
|
||||
/>
|
||||
</AppDndContext>
|
||||
|
||||
@@ -68,7 +68,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
objects: [imageObject],
|
||||
};
|
||||
|
||||
api.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
|
||||
api.dispatch(rasterLayerAdded({ overrides, isSelected: false }));
|
||||
api.dispatch(sessionStagingAreaReset());
|
||||
},
|
||||
});
|
||||
|
||||
@@ -31,7 +31,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
let didStartStaging = false;
|
||||
|
||||
if (!state.canvasSession.isStaging && state.canvasSession.mode === 'compose') {
|
||||
if (!state.canvasSession.isStaging && state.canvasSession.sendToCanvas) {
|
||||
dispatch(sessionStartedStaging());
|
||||
didStartStaging = true;
|
||||
}
|
||||
@@ -70,7 +70,11 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
const { g, noise, posCond } = buildGraphResult.value;
|
||||
|
||||
const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond));
|
||||
const destination = state.canvasSession.sendToCanvas ? 'canvas' : 'gallery';
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
|
||||
);
|
||||
|
||||
if (isErr(prepareBatchResult)) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
|
||||
@@ -32,6 +32,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
workflow: builtWorkflow,
|
||||
runs: state.params.iterations,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
},
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
@@ -16,7 +16,7 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
|
||||
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
|
||||
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond, 'upscaling', 'gallery');
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
|
||||
104
invokeai/frontend/web/src/common/components/IconSwitch.tsx
Normal file
104
invokeai/frontend/web/src/common/components/IconSwitch.tsx
Normal file
@@ -0,0 +1,104 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, IconButton, Tooltip, useToken } from '@invoke-ai/ui-library';
|
||||
import type { ReactElement, ReactNode } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type IconSwitchProps = {
|
||||
isChecked: boolean;
|
||||
onChange: (checked: boolean) => void;
|
||||
iconChecked: ReactElement;
|
||||
tooltipChecked?: ReactNode;
|
||||
iconUnchecked: ReactElement;
|
||||
tooltipUnchecked?: ReactNode;
|
||||
ariaLabel: string;
|
||||
};
|
||||
|
||||
const getSx = (padding: string | number): SystemStyleObject => ({
|
||||
transition: 'left 0.1s ease-in-out, transform 0.1s ease-in-out',
|
||||
'&[data-checked="true"]': {
|
||||
left: `calc(100% - ${padding})`,
|
||||
transform: 'translateX(-100%)',
|
||||
},
|
||||
'&[data-checked="false"]': {
|
||||
left: padding,
|
||||
transform: 'translateX(0)',
|
||||
},
|
||||
});
|
||||
|
||||
export const IconSwitch = memo(
|
||||
({
|
||||
isChecked,
|
||||
onChange,
|
||||
iconChecked,
|
||||
tooltipChecked,
|
||||
iconUnchecked,
|
||||
tooltipUnchecked,
|
||||
ariaLabel,
|
||||
}: IconSwitchProps) => {
|
||||
const onUncheck = useCallback(() => {
|
||||
onChange(false);
|
||||
}, [onChange]);
|
||||
const onCheck = useCallback(() => {
|
||||
onChange(true);
|
||||
}, [onChange]);
|
||||
|
||||
const gap = useToken('space', 1.5);
|
||||
const sx = useMemo(() => getSx(gap), [gap]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
h="full"
|
||||
p={gap}
|
||||
gap={gap}
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
borderRadius="base"
|
||||
bg="invokeBlue.400"
|
||||
w={12}
|
||||
top={gap}
|
||||
bottom={gap}
|
||||
data-checked={isChecked}
|
||||
sx={sx}
|
||||
/>
|
||||
<Tooltip hasArrow label={tooltipUnchecked}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
fontSize={16}
|
||||
icon={iconUnchecked}
|
||||
onClick={onUncheck}
|
||||
variant={!isChecked ? 'solid' : 'ghost'}
|
||||
colorScheme={!isChecked ? 'invokeBlue' : 'base'}
|
||||
aria-label={ariaLabel}
|
||||
data-checked={!isChecked}
|
||||
w={12}
|
||||
alignSelf="stretch"
|
||||
h="auto"
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip hasArrow label={tooltipChecked}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
fontSize={16}
|
||||
icon={iconChecked}
|
||||
onClick={onCheck}
|
||||
variant={isChecked ? 'solid' : 'ghost'}
|
||||
colorScheme={isChecked ? 'invokeBlue' : 'base'}
|
||||
aria-label={ariaLabel}
|
||||
data-checked={isChecked}
|
||||
w={12}
|
||||
alignSelf="stretch"
|
||||
h="auto"
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IconSwitch.displayName = 'IconSwitch';
|
||||
@@ -1,52 +1,74 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { WritableAtom } from 'nanostores';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useState } from 'react';
|
||||
|
||||
export const useBoolean = (initialValue: boolean) => {
|
||||
const [isTrue, set] = useState(initialValue);
|
||||
const setTrue = useCallback(() => set(true), []);
|
||||
const setFalse = useCallback(() => set(false), []);
|
||||
const toggle = useCallback(() => set((v) => !v), []);
|
||||
type UseBoolean = {
|
||||
isTrue: boolean;
|
||||
setTrue: () => void;
|
||||
setFalse: () => void;
|
||||
set: (value: boolean) => void;
|
||||
toggle: () => void;
|
||||
};
|
||||
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
/**
|
||||
* Creates a hook to manage a boolean state. The boolean is stored in a nanostores atom.
|
||||
* Returns a tuple containing the hook and the atom. Use this for global boolean state.
|
||||
* @param initialValue Initial value of the boolean
|
||||
*/
|
||||
export const buildUseBoolean = (initialValue: boolean): [() => UseBoolean, WritableAtom<boolean>] => {
|
||||
const $boolean = atom(initialValue);
|
||||
|
||||
const setTrue = () => {
|
||||
$boolean.set(true);
|
||||
};
|
||||
const setFalse = () => {
|
||||
$boolean.set(false);
|
||||
};
|
||||
const set = (value: boolean) => {
|
||||
$boolean.set(value);
|
||||
};
|
||||
const toggle = () => {
|
||||
$boolean.set(!$boolean.get());
|
||||
};
|
||||
|
||||
const useBoolean = () => {
|
||||
const isTrue = useStore($boolean);
|
||||
|
||||
return {
|
||||
isTrue,
|
||||
set,
|
||||
setTrue,
|
||||
setFalse,
|
||||
set,
|
||||
toggle,
|
||||
}),
|
||||
[isTrue, set, setTrue, setFalse, toggle]
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
return api;
|
||||
return [useBoolean, $boolean] as const;
|
||||
};
|
||||
|
||||
export const buildUseBoolean = ($boolean: WritableAtom<boolean>) => {
|
||||
return () => {
|
||||
const setTrue = useCallback(() => {
|
||||
$boolean.set(true);
|
||||
}, []);
|
||||
const setFalse = useCallback(() => {
|
||||
$boolean.set(false);
|
||||
}, []);
|
||||
const set = useCallback((value: boolean) => {
|
||||
$boolean.set(value);
|
||||
}, []);
|
||||
const toggle = useCallback(() => {
|
||||
$boolean.set(!$boolean.get());
|
||||
}, []);
|
||||
/**
|
||||
* Hook to manage a boolean state. Use this for a local boolean state.
|
||||
* @param initialValue Initial value of the boolean
|
||||
*/
|
||||
export const useBoolean = (initialValue: boolean) => {
|
||||
const [isTrue, set] = useState(initialValue);
|
||||
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
setTrue,
|
||||
setFalse,
|
||||
set,
|
||||
toggle,
|
||||
$boolean,
|
||||
}),
|
||||
[set, setFalse, setTrue, toggle]
|
||||
);
|
||||
const setTrue = useCallback(() => {
|
||||
set(true);
|
||||
}, [set]);
|
||||
const setFalse = useCallback(() => {
|
||||
set(false);
|
||||
}, [set]);
|
||||
const toggle = useCallback(() => {
|
||||
set((val) => !val);
|
||||
}, [set]);
|
||||
|
||||
return api;
|
||||
return {
|
||||
isTrue,
|
||||
setTrue,
|
||||
setFalse,
|
||||
set,
|
||||
toggle,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addScope, removeScope, setScopes } from 'common/hooks/interactionScopes';
|
||||
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { useQueueBack } from 'features/queue/hooks/useQueueBack';
|
||||
import { useQueueFront } from 'features/queue/hooks/useQueueFront';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
|
||||
@@ -31,22 +31,22 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<ButtonGroup orientation="vertical" isAttached={false}>
|
||||
<Flex flexDir="column" w="full" h="full" alignItems="center">
|
||||
<ButtonGroup position="relative" orientation="vertical" isAttached={false} top="20%">
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addInpaintMask}>
|
||||
{t('controlLayers.inpaintMask', { count: 1 })}
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRegionalGuidance}>
|
||||
{t('controlLayers.regionalGuidance', { count: 1 })}
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRasterLayer}>
|
||||
{t('controlLayers.rasterLayer', { count: 1 })}
|
||||
{t('controlLayers.rasterLayer')}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addControlLayer}>
|
||||
{t('controlLayers.controlLayer', { count: 1 })}
|
||||
{t('controlLayers.controlLayer')}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||
{t('controlLayers.ipAdapter', { count: 1 })}
|
||||
{t('controlLayers.ipAdapter')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
|
||||
@@ -33,19 +33,19 @@ export const CanvasEntityListMenuItems = memo(() => {
|
||||
return (
|
||||
<>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
|
||||
{t('controlLayers.inpaintMask', { count: 1 })}
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance}>
|
||||
{t('controlLayers.regionalGuidance', { count: 1 })}
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
|
||||
{t('controlLayers.rasterLayer', { count: 1 })}
|
||||
{t('controlLayers.rasterLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
|
||||
{t('controlLayers.controlLayer', { count: 1 })}
|
||||
{t('controlLayers.controlLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||
{t('controlLayers.ipAdapter', { count: 1 })}
|
||||
{t('controlLayers.ipAdapter')}
|
||||
</MenuItem>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -157,7 +157,7 @@ export const SelectedEntityOpacity = memo(() => {
|
||||
clampValueOnBlur={false}
|
||||
variant="outline"
|
||||
>
|
||||
<NumberInputField paddingInlineEnd={7} />
|
||||
<NumberInputField paddingInlineEnd={7} _focusVisible={{ zIndex: 0 }} />
|
||||
<PopoverTrigger>
|
||||
<IconButton
|
||||
aria-label="open-slider"
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import { Button, ButtonGroup } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasSessionSlice, sessionModeChanged } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectCanvasMode = createSelector(selectCanvasSessionSlice, (canvasSession) => canvasSession.mode);
|
||||
|
||||
export const CanvasModeSwitcher = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const mode = useAppSelector(selectCanvasMode);
|
||||
const onClickGenerate = useCallback(() => dispatch(sessionModeChanged({ mode: 'generate' })), [dispatch]);
|
||||
const onClickCompose = useCallback(() => dispatch(sessionModeChanged({ mode: 'compose' })), [dispatch]);
|
||||
|
||||
return (
|
||||
<ButtonGroup variant="outline">
|
||||
<Button onClick={onClickGenerate} colorScheme={mode === 'generate' ? 'invokeBlue' : 'base'}>
|
||||
{t('controlLayers.generateMode')}
|
||||
</Button>
|
||||
<Button onClick={onClickCompose} colorScheme={mode === 'compose' ? 'invokeBlue' : 'base'}>
|
||||
{t('controlLayers.composeMode')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasModeSwitcher.displayName = 'CanvasModeSwitcher';
|
||||
@@ -9,6 +9,7 @@ import { memo } from 'react';
|
||||
|
||||
export const CanvasPanelContent = memo(() => {
|
||||
const hasEntities = useAppSelector(selectHasEntities);
|
||||
|
||||
return (
|
||||
<CanvasManagerProviderGate>
|
||||
<Flex flexDir="column" gap={2} w="full" h="full">
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { IconSwitch } from 'common/components/IconSwitch';
|
||||
import { selectIsComposing, sessionSendToCanvasChanged } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImageBold, PiPaintBrushBold } from 'react-icons/pi';
|
||||
|
||||
const TooltipSendToGallery = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToGallery')}</Text>
|
||||
<Text fontWeight="normal">{t('controlLayers.sendToGalleryDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipSendToGallery.displayName = 'TooltipSendToGallery';
|
||||
|
||||
const TooltipSendToCanvas = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToCanvas')}</Text>
|
||||
<Text fontWeight="normal">{t('controlLayers.sendToCanvasDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipSendToCanvas.displayName = 'TooltipSendToCanvas';
|
||||
|
||||
export const CanvasSendToToggle = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isComposing = useAppSelector(selectIsComposing);
|
||||
|
||||
const onChange = useCallback(
|
||||
(isChecked: boolean) => {
|
||||
dispatch(sessionSendToCanvasChanged(isChecked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IconSwitch
|
||||
isChecked={isComposing}
|
||||
onChange={onChange}
|
||||
iconUnchecked={<PiImageBold />}
|
||||
tooltipUnchecked={<TooltipSendToGallery />}
|
||||
iconChecked={<PiPaintBrushBold />}
|
||||
tooltipChecked={<TooltipSendToCanvas />}
|
||||
ariaLabel="Toggle canvas mode"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSendToToggle.displayName = 'CanvasSendToToggle';
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
@@ -29,8 +28,7 @@ export const ControlLayer = memo(({ id }: Props) => {
|
||||
<CanvasEntityEditableTitle />
|
||||
<Spacer />
|
||||
<ControlLayerBadges />
|
||||
<CanvasEntityIsLockedToggle />
|
||||
<CanvasEntityEnabledToggle />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<ControlLayerControlAdapter />
|
||||
|
||||
@@ -1,19 +1,21 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasModeSwitcher } from 'features/controlLayers/components/CanvasModeSwitcher';
|
||||
import { CanvasResetViewButton } from 'features/controlLayers/components/CanvasResetViewButton';
|
||||
import { CanvasScale } from 'features/controlLayers/components/CanvasScale';
|
||||
import { SaveToGalleryButton } from 'features/controlLayers/components/SaveToGalleryButton';
|
||||
import { CanvasSettingsPopover } from 'features/controlLayers/components/Settings/CanvasSettingsPopover';
|
||||
import { ToolChooser } from 'features/controlLayers/components/Tool/ToolChooser';
|
||||
import { ToolFillColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker';
|
||||
import { ToolSettings } from 'features/controlLayers/components/Tool/ToolSettings';
|
||||
import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasUndoRedo } from 'features/controlLayers/hooks/useCanvasUndoRedo';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
|
||||
import { ViewerToggle } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const ControlLayersToolbar = memo(() => {
|
||||
useCanvasUndoRedo();
|
||||
|
||||
return (
|
||||
<CanvasManagerProviderGate>
|
||||
<Flex w="full" gap={2} alignItems="center">
|
||||
@@ -26,10 +28,9 @@ export const ControlLayersToolbar = memo(() => {
|
||||
<CanvasResetViewButton />
|
||||
<Spacer />
|
||||
<ToolFillColorPicker />
|
||||
<CanvasModeSwitcher />
|
||||
<UndoRedoButtonGroup />
|
||||
<SaveToGalleryButton />
|
||||
<CanvasSettingsPopover />
|
||||
<ViewerToggleMenu />
|
||||
<ViewerToggle />
|
||||
</Flex>
|
||||
</CanvasManagerProviderGate>
|
||||
);
|
||||
|
||||
@@ -1,53 +1,27 @@
|
||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { Grid, GridItem, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { round } from 'lodash-es';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectBbox = createSelector(selectCanvasSlice, (canvas) => canvas.bbox);
|
||||
|
||||
export const HeadsUpDisplay = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const stageAttrs = useStore(canvasManager.stateApi.$stageAttrs);
|
||||
const cursorPos = useStore(canvasManager.stateApi.$lastCursorPos);
|
||||
const isDrawing = useStore(canvasManager.stateApi.$isDrawing);
|
||||
const isMouseDown = useStore(canvasManager.stateApi.$isMouseDown);
|
||||
const lastMouseDownPos = useStore(canvasManager.stateApi.$lastMouseDownPos);
|
||||
const lastAddedPoint = useStore(canvasManager.stateApi.$lastAddedPoint);
|
||||
const bbox = useAppSelector(selectBbox);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" bg="blackAlpha.400" borderBottomEndRadius="base" p={2} minW={64} gap={2}>
|
||||
<HUDItem label="Zoom" value={`${round(stageAttrs.scale * 100, 2)}%`} />
|
||||
<HUDItem label="Stage Pos" value={`${round(stageAttrs.x, 3)}, ${round(stageAttrs.y, 3)}`} />
|
||||
<HUDItem
|
||||
label="Stage Size"
|
||||
value={`${round(stageAttrs.width / stageAttrs.scale, 2)}×${round(stageAttrs.height / stageAttrs.scale, 2)} px`}
|
||||
/>
|
||||
<HUDItem label="BBox Size" value={`${bbox.rect.width}×${bbox.rect.height} px`} />
|
||||
<HUDItem label="BBox Position" value={`${bbox.rect.x}, ${bbox.rect.y}`} />
|
||||
<HUDItem label="BBox Width % 8" value={round(bbox.rect.width % 8, 2)} />
|
||||
<HUDItem label="BBox Height % 8" value={round(bbox.rect.height % 8, 2)} />
|
||||
<HUDItem label="BBox X % 8" value={round(bbox.rect.x % 8, 2)} />
|
||||
<HUDItem label="BBox Y % 8" value={round(bbox.rect.y % 8, 2)} />
|
||||
<HUDItem
|
||||
label="Cursor Position"
|
||||
value={cursorPos ? `${round(cursorPos.x, 2)}, ${round(cursorPos.y, 2)}` : '?, ?'}
|
||||
/>
|
||||
<HUDItem label="Is Drawing" value={isDrawing ? 'True' : 'False'} />
|
||||
<HUDItem label="Is Mouse Down" value={isMouseDown ? 'True' : 'False'} />
|
||||
<HUDItem
|
||||
label="Last Mouse Down Pos"
|
||||
value={lastMouseDownPos ? `${round(lastMouseDownPos.x, 2)}, ${round(lastMouseDownPos.y, 2)}` : '?, ?'}
|
||||
/>
|
||||
<HUDItem
|
||||
label="Last Added Point"
|
||||
value={lastAddedPoint ? `${round(lastAddedPoint.x, 2)}, ${round(lastAddedPoint.y, 2)}` : '?, ?'}
|
||||
/>
|
||||
</Flex>
|
||||
<Grid
|
||||
bg="base.900"
|
||||
borderBottomEndRadius="base"
|
||||
p={2}
|
||||
gap={2}
|
||||
borderRadius="base"
|
||||
templateColumns="auto auto"
|
||||
opacity={0.6}
|
||||
>
|
||||
<HUDItem label="BBox" value={`${bbox.rect.width}×${bbox.rect.height} px`} />
|
||||
<HUDItem label="Scaled BBox" value={`${bbox.scaledSize.width}×${bbox.scaledSize.height} px`} />
|
||||
</Grid>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -55,12 +29,14 @@ HeadsUpDisplay.displayName = 'HeadsUpDisplay';
|
||||
|
||||
const HUDItem = memo(({ label, value }: { label: string; value: string | number }) => {
|
||||
return (
|
||||
<Box display="inline-block" lineHeight={1}>
|
||||
<Text as="span">{label}: </Text>
|
||||
<Text as="span" fontWeight="semibold">
|
||||
{value}
|
||||
</Text>
|
||||
</Box>
|
||||
<>
|
||||
<GridItem>
|
||||
<Text textAlign="end">{label}: </Text>
|
||||
</GridItem>
|
||||
<GridItem fontWeight="semibold">
|
||||
<Text>{value}</Text>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { IPAdapterSettings } from 'features/controlLayers/components/IPAdapter/IPAdapterSettings';
|
||||
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
@@ -18,10 +18,10 @@ export const IPAdapter = memo(({ id }: Props) => {
|
||||
return (
|
||||
<EntityIdentifierContext.Provider value={entityIdentifier}>
|
||||
<CanvasEntityContainer>
|
||||
<CanvasEntityHeader ps={4}>
|
||||
<CanvasEntityHeader ps={4} py={5}>
|
||||
<CanvasEntityEditableTitle />
|
||||
<Spacer />
|
||||
<CanvasEntityEnabledToggle />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
<IPAdapterSettings />
|
||||
</CanvasEntityContainer>
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { EntityMaskAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
@@ -25,8 +24,7 @@ export const InpaintMask = memo(({ id }: Props) => {
|
||||
<CanvasEntityPreviewImage />
|
||||
<CanvasEntityEditableTitle />
|
||||
<Spacer />
|
||||
<CanvasEntityIsLockedToggle />
|
||||
<CanvasEntityEnabledToggle />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
</CanvasEntityContainer>
|
||||
</EntityMaskAdapterGate>
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { EntityLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
@@ -25,8 +24,7 @@ export const RasterLayer = memo(({ id }: Props) => {
|
||||
<CanvasEntityPreviewImage />
|
||||
<CanvasEntityEditableTitle />
|
||||
<Spacer />
|
||||
<CanvasEntityIsLockedToggle />
|
||||
<CanvasEntityEnabledToggle />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
</CanvasEntityContainer>
|
||||
</EntityLayerAdapterGate>
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { RegionalGuidanceBadges } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceBadges';
|
||||
@@ -28,8 +27,7 @@ export const RegionalGuidance = memo(({ id }: Props) => {
|
||||
<CanvasEntityEditableTitle />
|
||||
<Spacer />
|
||||
<RegionalGuidanceBadges />
|
||||
<CanvasEntityIsLockedToggle />
|
||||
<CanvasEntityEnabledToggle />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
</CanvasEntityHeader>
|
||||
<RegionalGuidanceSettings />
|
||||
</CanvasEntityContainer>
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import { IconButton, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { isOk, withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFloppyDiskBold } from 'react-icons/pi';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
const [useIsSaving] = buildUseBoolean(false);
|
||||
|
||||
export const SaveToGalleryButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const shift = useShiftModifier();
|
||||
const canvasManager = useCanvasManager();
|
||||
const isSaving = useIsSaving();
|
||||
|
||||
const onClick = useCallback(async () => {
|
||||
isSaving.setTrue();
|
||||
|
||||
const rect = shift ? canvasManager.stateApi.getBbox().rect : canvasManager.stage.getVisibleRect('raster_layer');
|
||||
|
||||
const result = await withResultAsync(() =>
|
||||
canvasManager.compositor.rasterizeAndUploadCompositeRasterLayer(rect, true)
|
||||
);
|
||||
|
||||
if (isOk(result)) {
|
||||
toast({ title: t('controlLayers.savedToGalleryOk') });
|
||||
} else {
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to save canvas to gallery');
|
||||
toast({ title: t('controlLayers.savedToGalleryError'), status: 'error' });
|
||||
}
|
||||
|
||||
isSaving.setFalse();
|
||||
}, [canvasManager.compositor, canvasManager.stage, canvasManager.stateApi, isSaving, shift, t]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
variant="ghost"
|
||||
onClick={onClick}
|
||||
icon={<PiFloppyDiskBold />}
|
||||
isLoading={isSaving.isTrue}
|
||||
aria-label={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
|
||||
tooltip={shift ? t('controlLayers.saveBboxToGallery') : t('controlLayers.saveCanvasToGallery')}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
SaveToGalleryButton.displayName = 'SaveToGalleryButton';
|
||||
@@ -18,6 +18,7 @@ import { CanvasSettingsInvertScrollCheckbox } from 'features/controlLayers/compo
|
||||
import { CanvasSettingsLogDebugInfoButton } from 'features/controlLayers/components/Settings/CanvasSettingsLogDebugInfo';
|
||||
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
|
||||
import { CanvasSettingsResetButton } from 'features/controlLayers/components/Settings/CanvasSettingsResetButton';
|
||||
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { RiSettings4Fill } from 'react-icons/ri';
|
||||
@@ -37,6 +38,7 @@ export const CanvasSettingsPopover = memo(() => {
|
||||
<CanvasSettingsInvertScrollCheckbox />
|
||||
<CanvasSettingsClipToBboxCheckbox />
|
||||
<CanvasSettingsDynamicGridSwitch />
|
||||
<CanvasSettingsShowHUDSwitch />
|
||||
<CanvasSettingsResetButton />
|
||||
<DebugSettings />
|
||||
</Flex>
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasSettingsSlice, settingsShowHUDToggled } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectShowHUD = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.showHUD);
|
||||
|
||||
export const CanvasSettingsShowHUDSwitch = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const showHUD = useAppSelector(selectShowHUD);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsShowHUDToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel m={0} flexGrow={1}>
|
||||
{t('controlLayers.showHUD')}
|
||||
</FormLabel>
|
||||
<Switch size="sm" isChecked={showHUD} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsShowHUDSwitch.displayName = 'CanvasSettingsShowHUDSwitch';
|
||||
@@ -8,20 +8,18 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { HeadsUpDisplay } from 'features/controlLayers/components/HeadsUpDisplay';
|
||||
import { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { TRANSPARENCY_CHECKER_PATTERN } from 'features/controlLayers/konva/constants';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import Konva from 'konva';
|
||||
import { memo, useCallback, useEffect, useLayoutEffect, useState } from 'react';
|
||||
import { useDevicePixelRatio } from 'use-device-pixel-ratio';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
const showHud = false;
|
||||
|
||||
// This will log warnings when layers > 5 - maybe use `import.meta.env.MODE === 'development'` instead?
|
||||
Konva.showWarnings = false;
|
||||
|
||||
const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null, asPreview: boolean) => {
|
||||
const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null) => {
|
||||
const store = useAppStore();
|
||||
const socket = useStore($socket);
|
||||
const dpr = useDevicePixelRatio({ round: false });
|
||||
@@ -42,28 +40,25 @@ const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null,
|
||||
const manager = new CanvasManager(stage, container, store, socket);
|
||||
manager.initialize();
|
||||
return manager.destroy;
|
||||
}, [asPreview, container, socket, stage, store]);
|
||||
}, [container, socket, stage, store]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
Konva.pixelRatio = dpr;
|
||||
}, [dpr]);
|
||||
};
|
||||
|
||||
type Props = {
|
||||
asPreview?: boolean;
|
||||
};
|
||||
|
||||
const selectDynamicGrid = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.dynamicGrid);
|
||||
const selectShowHUD = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.showHUD);
|
||||
|
||||
export const StageComponent = memo(({ asPreview = false }: Props) => {
|
||||
export const StageComponent = memo(() => {
|
||||
const dynamicGrid = useAppSelector(selectDynamicGrid);
|
||||
const showHUD = useAppSelector(selectShowHUD);
|
||||
|
||||
const [stage] = useState(
|
||||
() =>
|
||||
new Konva.Stage({
|
||||
id: uuidv4(),
|
||||
id: getPrefixedId('konva_stage'),
|
||||
container: document.createElement('div'),
|
||||
listening: !asPreview,
|
||||
})
|
||||
);
|
||||
const [container, setContainer] = useState<HTMLDivElement | null>(null);
|
||||
@@ -72,7 +67,7 @@ export const StageComponent = memo(({ asPreview = false }: Props) => {
|
||||
setContainer(el);
|
||||
}, []);
|
||||
|
||||
useStageRenderer(stage, container, asPreview);
|
||||
useStageRenderer(stage, container);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
@@ -106,9 +101,9 @@ export const StageComponent = memo(({ asPreview = false }: Props) => {
|
||||
overflow="hidden"
|
||||
data-testid="control-layers-canvas"
|
||||
/>
|
||||
{!asPreview && (
|
||||
<Flex position="absolute" top={0} insetInlineStart={0} pointerEvents="none">
|
||||
{showHud && <HeadsUpDisplay />}
|
||||
{showHUD && (
|
||||
<Flex position="absolute" top={1} insetInlineStart={1} pointerEvents="none">
|
||||
<HeadsUpDisplay />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
@@ -98,9 +98,9 @@ export const StagingAreaToolbar = memo(() => {
|
||||
onPrev,
|
||||
{
|
||||
preventDefault: true,
|
||||
enabled: isCanvasActive,
|
||||
enabled: isCanvasActive && shouldShowStagedImage && imageCount > 1,
|
||||
},
|
||||
[isCanvasActive]
|
||||
[isCanvasActive, shouldShowStagedImage, imageCount]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@@ -108,9 +108,9 @@ export const StagingAreaToolbar = memo(() => {
|
||||
onNext,
|
||||
{
|
||||
preventDefault: true,
|
||||
enabled: isCanvasActive,
|
||||
enabled: isCanvasActive && shouldShowStagedImage && imageCount > 1,
|
||||
},
|
||||
[isCanvasActive]
|
||||
[isCanvasActive, shouldShowStagedImage, imageCount]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@@ -118,9 +118,9 @@ export const StagingAreaToolbar = memo(() => {
|
||||
onAccept,
|
||||
{
|
||||
preventDefault: true,
|
||||
enabled: isCanvasActive,
|
||||
enabled: isCanvasActive && shouldShowStagedImage && imageCount > 1,
|
||||
},
|
||||
[isCanvasActive]
|
||||
[isCanvasActive, shouldShowStagedImage, imageCount]
|
||||
);
|
||||
|
||||
const counterText = useMemo(() => {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Flex, Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Popover, PopoverBody, PopoverContent, PopoverTrigger, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||
@@ -23,17 +23,13 @@ export const ToolFillColorPicker = memo(() => {
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<Flex
|
||||
as="button"
|
||||
aria-label={t('controlLayers.brushColor')}
|
||||
borderRadius="full"
|
||||
borderWidth={1}
|
||||
bg={rgbaColorToString(fill)}
|
||||
w={8}
|
||||
h={8}
|
||||
cursor="pointer"
|
||||
tabIndex={-1}
|
||||
/>
|
||||
<Flex role="button" aria-label={t('controlLayers.fill.fillColor')} tabIndex={-1} w={8} h={8}>
|
||||
<Tooltip label={t('controlLayers.fill.fillColor')}>
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Box borderRadius="full" w={6} h={6} borderWidth={1} bg={rgbaColorToString(fill)} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverBody minH={64}>
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
type: CanvasEntityIdentifier['type'];
|
||||
};
|
||||
|
||||
export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const onClick = useCallback(() => {
|
||||
switch (type) {
|
||||
case 'inpaint_mask':
|
||||
dispatch(inpaintMaskAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'regional_guidance':
|
||||
dispatch(rgAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'raster_layer':
|
||||
dispatch(rasterLayerAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'control_layer':
|
||||
dispatch(controlLayerAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'ip_adapter':
|
||||
dispatch(ipaAdded({ isSelected: true }));
|
||||
break;
|
||||
}
|
||||
}, [dispatch, type]);
|
||||
|
||||
const label = useMemo(() => {
|
||||
switch (type) {
|
||||
case 'inpaint_mask':
|
||||
return t('controlLayers.addInpaintMask');
|
||||
case 'regional_guidance':
|
||||
return t('controlLayers.addRegionalGuidance');
|
||||
case 'raster_layer':
|
||||
return t('controlLayers.addRasterLayer');
|
||||
case 'control_layer':
|
||||
return t('controlLayers.addControlLayer');
|
||||
case 'ip_adapter':
|
||||
return t('controlLayers.addIPAdapter');
|
||||
}
|
||||
}, [type, t]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label={label}
|
||||
tooltip={label}
|
||||
variant="link"
|
||||
icon={<PiPlusBold />}
|
||||
onClick={onClick}
|
||||
alignSelf="stretch"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityAddOfTypeButton.displayName = 'CanvasEntityAddOfTypeButton';
|
||||
@@ -0,0 +1,31 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { entityDeleted } from 'features/controlLayers/store/canvasSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleFill } from 'react-icons/pi';
|
||||
|
||||
export const CanvasEntityDeleteButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const dispatch = useAppDispatch();
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(entityDeleted({ entityIdentifier }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label={t('common.delete')}
|
||||
tooltip={t('common.delete')}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={<PiTrashSimpleFill />}
|
||||
onClick={onClick}
|
||||
colorScheme="error"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityDeleteButton.displayName = 'CanvasEntityDeleteButton';
|
||||
@@ -21,7 +21,8 @@ export const CanvasEntityEnabledToggle = memo(() => {
|
||||
size="sm"
|
||||
aria-label={t(isEnabled ? 'common.enabled' : 'common.disabled')}
|
||||
tooltip={t(isEnabled ? 'common.enabled' : 'common.disabled')}
|
||||
variant="ghost"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={isEnabled ? <PiCircleFill /> : <PiCircleBold />}
|
||||
onClick={onClick}
|
||||
/>
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
|
||||
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
|
||||
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
|
||||
import { useEntityTypeTitle } from 'features/controlLayers/hooks/useEntityTypeTitle';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
type Props = PropsWithChildren<{
|
||||
@@ -20,6 +22,9 @@ const _hover: SystemStyleObject = {
|
||||
export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props) => {
|
||||
const title = useEntityTypeTitle(type);
|
||||
const collapse = useBoolean(true);
|
||||
const canMergeVisible = useMemo(() => type === 'raster_layer' || type === 'inpaint_mask', [type]);
|
||||
const canHideAll = useMemo(() => type !== 'ip_adapter', [type]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex w="full">
|
||||
@@ -53,7 +58,9 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props
|
||||
</Text>
|
||||
<Spacer />
|
||||
</Flex>
|
||||
{type !== 'ip_adapter' && <CanvasEntityTypeIsHiddenToggle type={type} />}
|
||||
{canMergeVisible && <CanvasEntityMergeVisibleButton type={type} />}
|
||||
<CanvasEntityAddOfTypeButton type={type} />
|
||||
{canHideAll && <CanvasEntityTypeIsHiddenToggle type={type} />}
|
||||
</Flex>
|
||||
<Collapse in={collapse.isTrue}>
|
||||
<Flex flexDir="column" gap={2} pt={2}>
|
||||
|
||||
@@ -56,7 +56,7 @@ export const CanvasEntityHeader = memo(({ children, ...rest }: FlexProps) => {
|
||||
}, [entityIdentifier]);
|
||||
|
||||
return (
|
||||
<ContextMenu renderMenu={renderMenu} stopImmediatePropagation>
|
||||
<ContextMenu renderMenu={renderMenu}>
|
||||
{(ref) => (
|
||||
<Flex ref={ref} gap={2} alignItems="center" p={2} {...rest}>
|
||||
{children}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const CanvasEntityHeaderCommonActions = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
|
||||
return (
|
||||
<Flex alignSelf="stretch">
|
||||
{entityIdentifier.type !== 'ip_adapter' && <CanvasEntityIsLockedToggle />}
|
||||
<CanvasEntityEnabledToggle />
|
||||
<CanvasEntityDeleteButton />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityHeaderCommonActions.displayName = 'CanvasEntityHeaderCommonActions';
|
||||
@@ -21,7 +21,8 @@ export const CanvasEntityIsLockedToggle = memo(() => {
|
||||
size="sm"
|
||||
aria-label={t(isLocked ? 'controlLayers.locked' : 'controlLayers.unlocked')}
|
||||
tooltip={t(isLocked ? 'controlLayers.locked' : 'controlLayers.unlocked')}
|
||||
variant="ghost"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={isLocked ? <PiLockSimpleFill /> : <PiLockSimpleOpenBold />}
|
||||
onClick={onClick}
|
||||
/>
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { isOk, withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { inpaintMaskAdded, rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiStackBold } from 'react-icons/pi';
|
||||
import { serializeError } from 'serialize-error';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
type Props = {
|
||||
type: CanvasEntityIdentifier['type'];
|
||||
};
|
||||
|
||||
export const CanvasEntityMergeVisibleButton = memo(({ type }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const canvasManager = useCanvasManager();
|
||||
const onClick = useCallback(async () => {
|
||||
if (type === 'raster_layer') {
|
||||
const rect = canvasManager.stage.getVisibleRect('raster_layer');
|
||||
const result = await withResultAsync(() =>
|
||||
canvasManager.compositor.rasterizeAndUploadCompositeRasterLayer(rect, false)
|
||||
);
|
||||
|
||||
if (isOk(result)) {
|
||||
dispatch(
|
||||
rasterLayerAdded({
|
||||
isSelected: true,
|
||||
overrides: {
|
||||
objects: [imageDTOToImageObject(result.value)],
|
||||
position: { x: Math.floor(rect.x), y: Math.floor(rect.y) },
|
||||
},
|
||||
deleteOthers: true,
|
||||
})
|
||||
);
|
||||
toast({ title: t('controlLayers.mergeVisibleOk') });
|
||||
} else {
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to merge visible');
|
||||
toast({ title: t('controlLayers.mergeVisibleError'), status: 'error' });
|
||||
}
|
||||
} else if (type === 'inpaint_mask') {
|
||||
const rect = canvasManager.stage.getVisibleRect('inpaint_mask');
|
||||
const result = await withResultAsync(() =>
|
||||
canvasManager.compositor.rasterizeAndUploadCompositeInpaintMask(rect, false)
|
||||
);
|
||||
|
||||
if (isOk(result)) {
|
||||
dispatch(
|
||||
inpaintMaskAdded({
|
||||
isSelected: true,
|
||||
overrides: {
|
||||
objects: [imageDTOToImageObject(result.value)],
|
||||
position: { x: Math.floor(rect.x), y: Math.floor(rect.y) },
|
||||
},
|
||||
deleteOthers: true,
|
||||
})
|
||||
);
|
||||
toast({ title: t('controlLayers.mergeVisibleOk') });
|
||||
} else {
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to merge visible');
|
||||
toast({ title: t('controlLayers.mergeVisibleError'), status: 'error' });
|
||||
}
|
||||
} else {
|
||||
log.error({ type }, 'Unsupported type for merge visible');
|
||||
}
|
||||
}, [canvasManager.compositor, canvasManager.stage, dispatch, t, type]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label={t('controlLayers.mergeVisible')}
|
||||
tooltip={t('controlLayers.mergeVisible')}
|
||||
variant="link"
|
||||
icon={<PiStackBold />}
|
||||
onClick={onClick}
|
||||
alignSelf="stretch"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasEntityMergeVisibleButton.displayName = 'CanvasEntityMergeVisibleButton';
|
||||
@@ -1,16 +1,14 @@
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
import { ButtonGroup, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { canvasRedo, canvasUndo } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasMayRedo, selectCanvasMayUndo } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowClockwiseBold, PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import { useDispatch } from 'react-redux';
|
||||
|
||||
export const UndoRedoButtonGroup = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
export const useCanvasUndoRedo = () => {
|
||||
useAssertSingleton('useCanvasUndoRedo');
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const mayUndo = useAppSelector(selectCanvasMayUndo);
|
||||
@@ -27,27 +25,4 @@ export const UndoRedoButtonGroup = memo(() => {
|
||||
mayRedo,
|
||||
handleRedo,
|
||||
]);
|
||||
|
||||
return (
|
||||
<ButtonGroup isAttached={false}>
|
||||
<IconButton
|
||||
aria-label={t('unifiedCanvas.undo')}
|
||||
tooltip={t('unifiedCanvas.undo')}
|
||||
onClick={handleUndo}
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
isDisabled={!mayUndo}
|
||||
variant="ghost"
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={t('unifiedCanvas.redo')}
|
||||
tooltip={t('unifiedCanvas.redo')}
|
||||
onClick={handleRedo}
|
||||
icon={<PiArrowClockwiseBold />}
|
||||
isDisabled={!mayRedo}
|
||||
variant="ghost"
|
||||
/>
|
||||
</ButtonGroup>
|
||||
);
|
||||
});
|
||||
|
||||
UndoRedoButtonGroup.displayName = 'UndoRedoButtonGroup';
|
||||
};
|
||||
@@ -29,15 +29,15 @@ export const useEntityTitle = (entityIdentifier: CanvasEntityIdentifier) => {
|
||||
|
||||
const parts: string[] = [];
|
||||
if (entityIdentifier.type === 'inpaint_mask') {
|
||||
parts.push(t('controlLayers.inpaintMask', { count: 1 }));
|
||||
parts.push(t('controlLayers.inpaintMask'));
|
||||
} else if (entityIdentifier.type === 'control_layer') {
|
||||
parts.push(t('controlLayers.controlLayer', { count: 1 }));
|
||||
parts.push(t('controlLayers.controlLayer'));
|
||||
} else if (entityIdentifier.type === 'raster_layer') {
|
||||
parts.push(t('controlLayers.rasterLayer', { count: 1 }));
|
||||
parts.push(t('controlLayers.rasterLayer'));
|
||||
} else if (entityIdentifier.type === 'ip_adapter') {
|
||||
parts.push(t('common.ipAdapter', { count: 1 }));
|
||||
parts.push(t('common.ipAdapter'));
|
||||
} else if (entityIdentifier.type === 'regional_guidance') {
|
||||
parts.push(t('controlLayers.regionalGuidance', { count: 1 }));
|
||||
parts.push(t('controlLayers.regionalGuidance'));
|
||||
} else {
|
||||
assert(false, 'Unexpected entity type');
|
||||
}
|
||||
|
||||
@@ -8,15 +8,15 @@ export const useEntityTypeString = (type: CanvasEntityIdentifier['type']): strin
|
||||
const typeString = useMemo(() => {
|
||||
switch (type) {
|
||||
case 'control_layer':
|
||||
return t('controlLayers.controlLayer', { count: 0 });
|
||||
return t('controlLayers.controlLayer');
|
||||
case 'raster_layer':
|
||||
return t('controlLayers.rasterLayer', { count: 0 });
|
||||
return t('controlLayers.rasterLayer');
|
||||
case 'inpaint_mask':
|
||||
return t('controlLayers.inpaintMask', { count: 0 });
|
||||
return t('controlLayers.inpaintMask');
|
||||
case 'regional_guidance':
|
||||
return t('controlLayers.regionalGuidance', { count: 0 });
|
||||
return t('controlLayers.regionalGuidance');
|
||||
case 'ip_adapter':
|
||||
return t('controlLayers.ipAdapter', { count: 0 });
|
||||
return t('controlLayers.ipAdapter');
|
||||
default:
|
||||
return '';
|
||||
}
|
||||
|
||||
@@ -147,6 +147,19 @@ export class CanvasCompositorModule extends CanvasModuleABC {
|
||||
return stableHash(data);
|
||||
};
|
||||
|
||||
rasterizeAndUploadCompositeRasterLayer = async (rect: Rect, saveToGallery: boolean) => {
|
||||
this.log.trace({ rect }, 'Rasterizing composite raster layer');
|
||||
|
||||
const canvas = this.getCompositeRasterLayerCanvas(rect);
|
||||
const blob = await canvasToBlob(canvas);
|
||||
|
||||
if (this.manager._isDebugging) {
|
||||
previewBlob(blob, 'Composite raster layer canvas');
|
||||
}
|
||||
|
||||
return uploadImage(blob, 'composite-raster-layer.png', 'general', !saveToGallery);
|
||||
};
|
||||
|
||||
getCompositeRasterLayerImageDTO = async (rect: Rect): Promise<ImageDTO> => {
|
||||
let imageDTO: ImageDTO | null = null;
|
||||
|
||||
@@ -161,19 +174,23 @@ export class CanvasCompositorModule extends CanvasModuleABC {
|
||||
}
|
||||
}
|
||||
|
||||
this.log.trace({ rect }, 'Rasterizing composite raster layer');
|
||||
|
||||
const canvas = this.getCompositeRasterLayerCanvas(rect);
|
||||
const blob = await canvasToBlob(canvas);
|
||||
if (this.manager._isDebugging) {
|
||||
previewBlob(blob, 'Composite raster layer canvas');
|
||||
}
|
||||
|
||||
imageDTO = await uploadImage(blob, 'composite-raster-layer.png', 'general', true);
|
||||
imageDTO = await this.rasterizeAndUploadCompositeRasterLayer(rect, false);
|
||||
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
rasterizeAndUploadCompositeInpaintMask = async (rect: Rect, saveToGallery: boolean) => {
|
||||
this.log.trace({ rect }, 'Rasterizing composite inpaint mask');
|
||||
|
||||
const canvas = this.getCompositeInpaintMaskCanvas(rect);
|
||||
const blob = await canvasToBlob(canvas);
|
||||
if (this.manager._isDebugging) {
|
||||
previewBlob(blob, 'Composite inpaint mask canvas');
|
||||
}
|
||||
|
||||
return uploadImage(blob, 'composite-inpaint-mask.png', 'general', !saveToGallery);
|
||||
};
|
||||
|
||||
getCompositeInpaintMaskImageDTO = async (rect: Rect): Promise<ImageDTO> => {
|
||||
let imageDTO: ImageDTO | null = null;
|
||||
|
||||
@@ -188,15 +205,7 @@ export class CanvasCompositorModule extends CanvasModuleABC {
|
||||
}
|
||||
}
|
||||
|
||||
this.log.trace({ rect }, 'Rasterizing composite inpaint mask');
|
||||
|
||||
const canvas = this.getCompositeInpaintMaskCanvas(rect);
|
||||
const blob = await canvasToBlob(canvas);
|
||||
if (this.manager._isDebugging) {
|
||||
previewBlob(blob, 'Composite inpaint mask canvas');
|
||||
}
|
||||
|
||||
imageDTO = await uploadImage(blob, 'composite-inpaint-mask.png', 'general', true);
|
||||
imageDTO = await this.rasterizeAndUploadCompositeInpaintMask(rect, false);
|
||||
this.manager.cache.imageNameCache.set(hash, imageDTO.image_name);
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { CanvasModuleABC } from 'features/controlLayers/konva/CanvasModuleABC';
|
||||
import { CANVAS_SCALE_BY } from 'features/controlLayers/konva/constants';
|
||||
import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
|
||||
import type { Coordinate, Dimensions, Rect } from 'features/controlLayers/store/types';
|
||||
import type { CanvasEntityIdentifier, Coordinate, Dimensions, Rect } from 'features/controlLayers/store/types';
|
||||
import type Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import { clamp } from 'lodash-es';
|
||||
@@ -76,35 +76,36 @@ export class CanvasStageModule extends CanvasModuleABC {
|
||||
});
|
||||
};
|
||||
|
||||
getVisibleRect = (): Rect => {
|
||||
getVisibleRect = (type?: Exclude<CanvasEntityIdentifier['type'], 'ip_adapter'>): Rect => {
|
||||
const rects = [];
|
||||
|
||||
for (const adapter of this.manager.adapters.getAll()) {
|
||||
if (adapter.state.isEnabled) {
|
||||
rects.push(adapter.transformer.getRelativeRect());
|
||||
if (!adapter.state.isEnabled) {
|
||||
continue;
|
||||
}
|
||||
if (type && adapter.state.type !== type) {
|
||||
continue;
|
||||
}
|
||||
rects.push(adapter.transformer.getRelativeRect());
|
||||
}
|
||||
|
||||
const rectUnion = getRectUnion(...rects);
|
||||
|
||||
if (rectUnion.width === 0 || rectUnion.height === 0) {
|
||||
// fall back to the bbox if there is no content
|
||||
return this.manager.stateApi.getBbox().rect;
|
||||
} else {
|
||||
return rectUnion;
|
||||
}
|
||||
return getRectUnion(...rects);
|
||||
};
|
||||
|
||||
fitBboxToStage = () => {
|
||||
this.log.trace('Fitting bbox to stage');
|
||||
const bbox = this.manager.stateApi.getBbox();
|
||||
this.fitRect(bbox.rect);
|
||||
const { rect } = this.manager.stateApi.getBbox();
|
||||
this.log.trace({ rect }, 'Fitting bbox to stage');
|
||||
this.fitRect(rect);
|
||||
};
|
||||
|
||||
fitLayersToStage() {
|
||||
this.log.trace('Fitting layers to stage');
|
||||
const rect = this.getVisibleRect();
|
||||
this.fitRect(rect);
|
||||
if (rect.width === 0 || rect.height === 0) {
|
||||
this.fitBboxToStage();
|
||||
} else {
|
||||
this.log.trace({ rect }, 'Fitting layers to stage');
|
||||
this.fitRect(rect);
|
||||
}
|
||||
}
|
||||
|
||||
fitRect = (rect: Rect) => {
|
||||
|
||||
@@ -250,12 +250,10 @@ export class CanvasToolModule extends CanvasModuleABC {
|
||||
this.konva.colorPicker.group.visible(tool === 'colorPicker');
|
||||
};
|
||||
|
||||
render = () => {
|
||||
syncCursorStyle = () => {
|
||||
const stage = this.manager.stage;
|
||||
const renderedEntityCount = this.manager.stateApi.getRenderedEntityCount();
|
||||
const toolState = this.manager.stateApi.getToolState();
|
||||
const selectedEntity = this.manager.stateApi.getSelectedEntity();
|
||||
const cursorPos = this.manager.stateApi.$lastCursorPos.get();
|
||||
const isMouseDown = this.manager.stateApi.$isMouseDown.get();
|
||||
const tool = this.manager.stateApi.$tool.get();
|
||||
|
||||
@@ -294,6 +292,158 @@ export class CanvasToolModule extends CanvasModuleABC {
|
||||
// Non-drawable layers don't have tools
|
||||
stage.container.style.cursor = 'not-allowed';
|
||||
}
|
||||
};
|
||||
|
||||
renderBrushTool = (cursorPos: Coordinate) => {
|
||||
const toolState = this.manager.stateApi.getToolState();
|
||||
const brushPreviewFill = this.manager.stateApi.getBrushPreviewFill();
|
||||
const alignedCursorPos = alignCoordForTool(cursorPos, toolState.brush.width);
|
||||
const onePixel = this.manager.stage.getScaledPixels(1);
|
||||
const twoPixels = this.manager.stage.getScaledPixels(2);
|
||||
const radius = toolState.brush.width / 2;
|
||||
|
||||
// The circle is scaled
|
||||
this.konva.brush.fillCircle.setAttrs({
|
||||
x: alignedCursorPos.x,
|
||||
y: alignedCursorPos.y,
|
||||
radius,
|
||||
fill: rgbaColorToString(brushPreviewFill),
|
||||
});
|
||||
|
||||
// But the borders are in screen-pixels
|
||||
this.konva.brush.innerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius,
|
||||
outerRadius: radius + onePixel,
|
||||
});
|
||||
this.konva.brush.outerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius + onePixel,
|
||||
outerRadius: radius + twoPixels,
|
||||
});
|
||||
};
|
||||
|
||||
renderEraserTool = (cursorPos: Coordinate) => {
|
||||
const toolState = this.manager.stateApi.getToolState();
|
||||
const alignedCursorPos = alignCoordForTool(cursorPos, toolState.eraser.width);
|
||||
const onePixel = this.manager.stage.getScaledPixels(1);
|
||||
const twoPixels = this.manager.stage.getScaledPixels(2);
|
||||
const radius = toolState.eraser.width / 2;
|
||||
|
||||
// The circle is scaled
|
||||
this.konva.eraser.fillCircle.setAttrs({
|
||||
x: alignedCursorPos.x,
|
||||
y: alignedCursorPos.y,
|
||||
radius,
|
||||
fill: 'white',
|
||||
});
|
||||
|
||||
// But the borders are in screen-pixels
|
||||
this.konva.eraser.innerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius,
|
||||
outerRadius: radius + onePixel,
|
||||
});
|
||||
this.konva.eraser.outerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius + onePixel,
|
||||
outerRadius: radius + twoPixels,
|
||||
});
|
||||
};
|
||||
|
||||
renderColorPicker = (cursorPos: Coordinate) => {
|
||||
const toolState = this.manager.stateApi.getToolState();
|
||||
const colorUnderCursor = this.manager.stateApi.$colorUnderCursor.get();
|
||||
const colorPickerInnerRadius = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_RADIUS);
|
||||
const colorPickerOuterRadius = this.manager.stage.getScaledPixels(
|
||||
CanvasToolModule.COLOR_PICKER_RADIUS + CanvasToolModule.COLOR_PICKER_THICKNESS
|
||||
);
|
||||
const onePixel = this.manager.stage.getScaledPixels(1);
|
||||
const twoPixels = this.manager.stage.getScaledPixels(2);
|
||||
|
||||
this.konva.colorPicker.newColor.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
fill: rgbColorToString(colorUnderCursor),
|
||||
innerRadius: colorPickerInnerRadius,
|
||||
outerRadius: colorPickerOuterRadius,
|
||||
});
|
||||
this.konva.colorPicker.oldColor.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
fill: rgbColorToString(toolState.fill),
|
||||
innerRadius: colorPickerInnerRadius,
|
||||
outerRadius: colorPickerOuterRadius,
|
||||
});
|
||||
this.konva.colorPicker.innerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: colorPickerOuterRadius,
|
||||
outerRadius: colorPickerOuterRadius + onePixel,
|
||||
});
|
||||
this.konva.colorPicker.outerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: colorPickerOuterRadius + onePixel,
|
||||
outerRadius: colorPickerOuterRadius + twoPixels,
|
||||
});
|
||||
|
||||
const size = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_CROSSHAIR_SIZE);
|
||||
const space = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_CROSSHAIR_SPACE);
|
||||
const innerThickness = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_CROSSHAIR_INNER_THICKNESS);
|
||||
const outerThickness = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_CROSSHAIR_OUTER_THICKNESS);
|
||||
this.konva.colorPicker.crosshairNorthOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x, cursorPos.y - size, cursorPos.x, cursorPos.y - space],
|
||||
});
|
||||
this.konva.colorPicker.crosshairNorthInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x, cursorPos.y - size, cursorPos.x, cursorPos.y - space],
|
||||
});
|
||||
this.konva.colorPicker.crosshairEastOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x + space, cursorPos.y, cursorPos.x + size, cursorPos.y],
|
||||
});
|
||||
this.konva.colorPicker.crosshairEastInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x + space, cursorPos.y, cursorPos.x + size, cursorPos.y],
|
||||
});
|
||||
this.konva.colorPicker.crosshairSouthOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x, cursorPos.y + space, cursorPos.x, cursorPos.y + size],
|
||||
});
|
||||
this.konva.colorPicker.crosshairSouthInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x, cursorPos.y + space, cursorPos.x, cursorPos.y + size],
|
||||
});
|
||||
this.konva.colorPicker.crosshairWestOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x - space, cursorPos.y, cursorPos.x - size, cursorPos.y],
|
||||
});
|
||||
this.konva.colorPicker.crosshairWestInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x - space, cursorPos.y, cursorPos.x - size, cursorPos.y],
|
||||
});
|
||||
};
|
||||
|
||||
render = () => {
|
||||
const stage = this.manager.stage;
|
||||
const renderedEntityCount = this.manager.stateApi.getRenderedEntityCount();
|
||||
const selectedEntity = this.manager.stateApi.getSelectedEntity();
|
||||
const cursorPos = this.manager.stateApi.$lastCursorPos.get();
|
||||
const tool = this.manager.stateApi.$tool.get();
|
||||
|
||||
const isDrawable =
|
||||
!!selectedEntity &&
|
||||
selectedEntity.state.isEnabled &&
|
||||
!selectedEntity.state.isLocked &&
|
||||
isDrawableEntity(selectedEntity.state);
|
||||
|
||||
this.syncCursorStyle();
|
||||
|
||||
stage.setIsDraggable(tool === 'view');
|
||||
|
||||
@@ -305,136 +455,11 @@ export class CanvasToolModule extends CanvasModuleABC {
|
||||
|
||||
// No need to render the brush preview if the cursor position or color is missing
|
||||
if (cursorPos && tool === 'brush') {
|
||||
const brushPreviewFill = this.manager.stateApi.getBrushPreviewFill();
|
||||
const alignedCursorPos = alignCoordForTool(cursorPos, toolState.brush.width);
|
||||
const onePixel = this.manager.stage.getScaledPixels(1);
|
||||
const twoPixels = this.manager.stage.getScaledPixels(2);
|
||||
const radius = toolState.brush.width / 2;
|
||||
|
||||
// The circle is scaled
|
||||
this.konva.brush.fillCircle.setAttrs({
|
||||
x: alignedCursorPos.x,
|
||||
y: alignedCursorPos.y,
|
||||
radius,
|
||||
fill: rgbaColorToString(brushPreviewFill),
|
||||
});
|
||||
|
||||
// But the borders are in screen-pixels
|
||||
this.konva.brush.innerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius,
|
||||
outerRadius: radius + onePixel,
|
||||
});
|
||||
this.konva.brush.outerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius + onePixel,
|
||||
outerRadius: radius + twoPixels,
|
||||
});
|
||||
this.renderBrushTool(cursorPos);
|
||||
} else if (cursorPos && tool === 'eraser') {
|
||||
const alignedCursorPos = alignCoordForTool(cursorPos, toolState.eraser.width);
|
||||
const onePixel = this.manager.stage.getScaledPixels(1);
|
||||
const twoPixels = this.manager.stage.getScaledPixels(2);
|
||||
const radius = toolState.eraser.width / 2;
|
||||
|
||||
// The circle is scaled
|
||||
this.konva.eraser.fillCircle.setAttrs({
|
||||
x: alignedCursorPos.x,
|
||||
y: alignedCursorPos.y,
|
||||
radius,
|
||||
fill: 'white',
|
||||
});
|
||||
|
||||
// But the borders are in screen-pixels
|
||||
this.konva.eraser.innerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius,
|
||||
outerRadius: radius + onePixel,
|
||||
});
|
||||
this.konva.eraser.outerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: radius + onePixel,
|
||||
outerRadius: radius + twoPixels,
|
||||
});
|
||||
this.renderEraserTool(cursorPos);
|
||||
} else if (cursorPos && tool === 'colorPicker') {
|
||||
const colorUnderCursor = this.manager.stateApi.$colorUnderCursor.get();
|
||||
const colorPickerInnerRadius = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_RADIUS);
|
||||
const colorPickerOuterRadius = this.manager.stage.getScaledPixels(
|
||||
CanvasToolModule.COLOR_PICKER_RADIUS + CanvasToolModule.COLOR_PICKER_THICKNESS
|
||||
);
|
||||
const onePixel = this.manager.stage.getScaledPixels(1);
|
||||
const twoPixels = this.manager.stage.getScaledPixels(2);
|
||||
|
||||
this.konva.colorPicker.newColor.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
fill: rgbColorToString(colorUnderCursor),
|
||||
innerRadius: colorPickerInnerRadius,
|
||||
outerRadius: colorPickerOuterRadius,
|
||||
});
|
||||
this.konva.colorPicker.oldColor.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
fill: rgbColorToString(toolState.fill),
|
||||
innerRadius: colorPickerInnerRadius,
|
||||
outerRadius: colorPickerOuterRadius,
|
||||
});
|
||||
this.konva.colorPicker.innerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: colorPickerOuterRadius,
|
||||
outerRadius: colorPickerOuterRadius + onePixel,
|
||||
});
|
||||
this.konva.colorPicker.outerBorder.setAttrs({
|
||||
x: cursorPos.x,
|
||||
y: cursorPos.y,
|
||||
innerRadius: colorPickerOuterRadius + onePixel,
|
||||
outerRadius: colorPickerOuterRadius + twoPixels,
|
||||
});
|
||||
|
||||
const size = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_CROSSHAIR_SIZE);
|
||||
const space = this.manager.stage.getScaledPixels(CanvasToolModule.COLOR_PICKER_CROSSHAIR_SPACE);
|
||||
const innerThickness = this.manager.stage.getScaledPixels(
|
||||
CanvasToolModule.COLOR_PICKER_CROSSHAIR_INNER_THICKNESS
|
||||
);
|
||||
const outerThickness = this.manager.stage.getScaledPixels(
|
||||
CanvasToolModule.COLOR_PICKER_CROSSHAIR_OUTER_THICKNESS
|
||||
);
|
||||
this.konva.colorPicker.crosshairNorthOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x, cursorPos.y - size, cursorPos.x, cursorPos.y - space],
|
||||
});
|
||||
this.konva.colorPicker.crosshairNorthInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x, cursorPos.y - size, cursorPos.x, cursorPos.y - space],
|
||||
});
|
||||
this.konva.colorPicker.crosshairEastOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x + space, cursorPos.y, cursorPos.x + size, cursorPos.y],
|
||||
});
|
||||
this.konva.colorPicker.crosshairEastInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x + space, cursorPos.y, cursorPos.x + size, cursorPos.y],
|
||||
});
|
||||
this.konva.colorPicker.crosshairSouthOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x, cursorPos.y + space, cursorPos.x, cursorPos.y + size],
|
||||
});
|
||||
this.konva.colorPicker.crosshairSouthInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x, cursorPos.y + space, cursorPos.x, cursorPos.y + size],
|
||||
});
|
||||
this.konva.colorPicker.crosshairWestOuter.setAttrs({
|
||||
strokeWidth: outerThickness,
|
||||
points: [cursorPos.x - space, cursorPos.y, cursorPos.x - size, cursorPos.y],
|
||||
});
|
||||
this.konva.colorPicker.crosshairWestInner.setAttrs({
|
||||
strokeWidth: innerThickness,
|
||||
points: [cursorPos.x - space, cursorPos.y, cursorPos.x - size, cursorPos.y],
|
||||
});
|
||||
this.renderColorPicker(cursorPos);
|
||||
}
|
||||
|
||||
this.setToolVisibility(tool, isDrawable);
|
||||
@@ -864,6 +889,10 @@ export class CanvasToolModule extends CanvasModuleABC {
|
||||
this.manager.stateApi.$spaceKey.set(true);
|
||||
this.manager.stateApi.$lastCursorPos.set(null);
|
||||
this.manager.stateApi.$lastMouseDownPos.set(null);
|
||||
} else if (e.key === 'Alt') {
|
||||
// Select the color picker on alt key down
|
||||
this.manager.stateApi.$toolBuffer.set(this.manager.stateApi.$tool.get());
|
||||
this.manager.stateApi.$tool.set('colorPicker');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -880,6 +909,11 @@ export class CanvasToolModule extends CanvasModuleABC {
|
||||
this.manager.stateApi.$tool.set(toolBuffer ?? 'move');
|
||||
this.manager.stateApi.$toolBuffer.set(null);
|
||||
this.manager.stateApi.$spaceKey.set(false);
|
||||
} else if (e.key === 'Alt') {
|
||||
// Revert the tool to the previous tool on alt key up
|
||||
const toolBuffer = this.manager.stateApi.$toolBuffer.get();
|
||||
this.manager.stateApi.$tool.set(toolBuffer ?? 'move');
|
||||
this.manager.stateApi.$toolBuffer.set(null);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe('util', () => {
|
||||
describe('getPrefixedId', () => {
|
||||
it('should return a prefixed id', () => {
|
||||
expect(getPrefixedId('foo').split(':')[0]).toBe('foo');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRectUnion', () => {
|
||||
it('should return the union of rects (2 rects)', () => {
|
||||
const rect1 = { x: 0, y: 0, width: 10, height: 10 };
|
||||
const rect2 = { x: 5, y: 5, width: 10, height: 10 };
|
||||
const union = getRectUnion(rect1, rect2);
|
||||
expect(union).toEqual({ x: 0, y: 0, width: 15, height: 15 });
|
||||
});
|
||||
it('should return the union of rects (3 rects)', () => {
|
||||
const rect1 = { x: 0, y: 0, width: 10, height: 10 };
|
||||
const rect2 = { x: 5, y: 5, width: 10, height: 10 };
|
||||
const rect3 = { x: 10, y: 10, width: 10, height: 10 };
|
||||
const union = getRectUnion(rect1, rect2, rect3);
|
||||
expect(union).toEqual({ x: 0, y: 0, width: 20, height: 20 });
|
||||
});
|
||||
it('should return the union of rects (2 rects none from zero)', () => {
|
||||
const rect1 = { x: 5, y: 5, width: 10, height: 10 };
|
||||
const rect2 = { x: 10, y: 10, width: 10, height: 10 };
|
||||
const union = getRectUnion(rect1, rect2);
|
||||
expect(union).toEqual({ x: 5, y: 5, width: 15, height: 15 });
|
||||
});
|
||||
it('should return the union of rects (2 rects with negative x/y)', () => {
|
||||
const rect1 = { x: -5, y: -5, width: 10, height: 10 };
|
||||
const rect2 = { x: 0, y: 0, width: 10, height: 10 };
|
||||
const union = getRectUnion(rect1, rect2);
|
||||
expect(union).toEqual({ x: -5, y: -5, width: 15, height: 15 });
|
||||
});
|
||||
it('should return the union of the first rect if only one rect is provided', () => {
|
||||
const rect = { x: 0, y: 0, width: 10, height: 10 };
|
||||
const union = getRectUnion(rect);
|
||||
expect(union).toEqual(rect);
|
||||
});
|
||||
it('should fall back on an empty rect if no rects are provided', () => {
|
||||
const union = getRectUnion();
|
||||
expect(union).toEqual({ x: 0, y: 0, width: 0, height: 0 });
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -302,10 +302,13 @@ export const konvaNodeToCanvas = (node: Konva.Node, bbox?: Rect): HTMLCanvasElem
|
||||
* @returns A Promise that resolves with Blob of the node cropped to the bounding box
|
||||
*/
|
||||
export const canvasToBlob = (canvas: HTMLCanvasElement): Promise<Blob> => {
|
||||
return new Promise((resolve) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
canvas.toBlob((blob) => {
|
||||
assert(blob, 'blob is null');
|
||||
resolve(blob);
|
||||
if (!blob) {
|
||||
reject('Failed to convert canvas to blob');
|
||||
} else {
|
||||
resolve(blob);
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
@@ -418,19 +421,25 @@ export function snapToNearest(value: number, candidateValues: number[], threshol
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the union of two rects
|
||||
* @param rect1 The first rect
|
||||
* @param rect2 The second rect
|
||||
* Gets the union of any number of rects.
|
||||
* @params rects The rects to union
|
||||
* @returns The union of the two rects
|
||||
*/
|
||||
export const getRectUnion = (...rects: Rect[]): Rect => {
|
||||
const firstRect = rects.shift();
|
||||
|
||||
if (!firstRect) {
|
||||
return getEmptyRect();
|
||||
}
|
||||
|
||||
const rect = rects.reduce<Rect>((acc, r) => {
|
||||
const x = Math.min(acc.x, r.x);
|
||||
const y = Math.min(acc.y, r.y);
|
||||
const width = Math.max(acc.x + acc.width, r.x + r.width) - x;
|
||||
const height = Math.max(acc.y + acc.height, r.y + r.height) - y;
|
||||
return { x, y, width, height };
|
||||
}, getEmptyRect());
|
||||
}, firstRect);
|
||||
|
||||
return rect;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import { createAction, createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { canvasSlice } from 'features/controlLayers/store/canvasSlice';
|
||||
import type { SessionMode, StagingAreaImage } from 'features/controlLayers/store/types';
|
||||
import type { StagingAreaImage } from 'features/controlLayers/store/types';
|
||||
|
||||
export type CanvasSessionState = {
|
||||
mode: SessionMode;
|
||||
sendToCanvas: boolean;
|
||||
isStaging: boolean;
|
||||
stagedImages: StagingAreaImage[];
|
||||
selectedStagedImageIndex: number;
|
||||
};
|
||||
|
||||
const initialState: CanvasSessionState = {
|
||||
mode: 'generate',
|
||||
sendToCanvas: false,
|
||||
isStaging: false,
|
||||
stagedImages: [],
|
||||
selectedStagedImageIndex: 0,
|
||||
@@ -27,6 +27,7 @@ export const canvasSessionSlice = createSlice({
|
||||
},
|
||||
sessionImageStaged: (state, action: PayloadAction<{ stagingAreaImage: StagingAreaImage }>) => {
|
||||
const { stagingAreaImage } = action.payload;
|
||||
state.isStaging = true;
|
||||
state.stagedImages.push(stagingAreaImage);
|
||||
state.selectedStagedImageIndex = state.stagedImages.length - 1;
|
||||
},
|
||||
@@ -50,9 +51,8 @@ export const canvasSessionSlice = createSlice({
|
||||
state.stagedImages = [];
|
||||
state.selectedStagedImageIndex = 0;
|
||||
},
|
||||
sessionModeChanged: (state, action: PayloadAction<{ mode: SessionMode }>) => {
|
||||
const { mode } = action.payload;
|
||||
state.mode = mode;
|
||||
sessionSendToCanvasChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.sendToCanvas = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -64,7 +64,7 @@ export const {
|
||||
sessionStagingAreaReset,
|
||||
sessionNextStagedImageSelected,
|
||||
sessionPrevStagedImageSelected,
|
||||
sessionModeChanged,
|
||||
sessionSendToCanvasChanged,
|
||||
} = canvasSessionSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
@@ -85,3 +85,7 @@ export const sessionStagingAreaImageAccepted = createAction<{ index: number }>(
|
||||
export const selectCanvasSessionSlice = (s: RootState) => s.canvasSession;
|
||||
|
||||
export const selectIsStaging = createSelector(selectCanvasSessionSlice, (canvasSession) => canvasSession.isStaging);
|
||||
export const selectIsComposing = createSelector(
|
||||
selectCanvasSessionSlice,
|
||||
(canvasSession) => canvasSession.sendToCanvas
|
||||
);
|
||||
|
||||
@@ -35,10 +35,14 @@ export const canvasSettingsSlice = createSlice({
|
||||
settingsAutoSaveToggled: (state) => {
|
||||
state.autoSave = !state.autoSave;
|
||||
},
|
||||
settingsShowHUDToggled: (state) => {
|
||||
state.showHUD = !state.showHUD;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { clipToBboxChanged, settingsAutoSaveToggled, settingsDynamicGridToggled } = canvasSettingsSlice.actions;
|
||||
export const { clipToBboxChanged, settingsAutoSaveToggled, settingsDynamicGridToggled, settingsShowHUDToggled } =
|
||||
canvasSettingsSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
|
||||
@@ -123,9 +123,14 @@ export const canvasSlice = createSlice({
|
||||
rasterLayerAdded: {
|
||||
reducer: (
|
||||
state,
|
||||
action: PayloadAction<{ id: string; overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }>
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
overrides?: Partial<CanvasRasterLayerState>;
|
||||
isSelected?: boolean;
|
||||
deleteOthers?: boolean;
|
||||
}>
|
||||
) => {
|
||||
const { id, overrides, isSelected } = action.payload;
|
||||
const { id, overrides, isSelected, deleteOthers } = action.payload;
|
||||
const entity: CanvasRasterLayerState = {
|
||||
id,
|
||||
name: null,
|
||||
@@ -137,12 +142,25 @@ export const canvasSlice = createSlice({
|
||||
position: { x: 0, y: 0 },
|
||||
};
|
||||
merge(entity, overrides);
|
||||
state.rasterLayers.entities.push(entity);
|
||||
|
||||
if (deleteOthers) {
|
||||
state.rasterLayers.entities = [entity];
|
||||
} else {
|
||||
state.rasterLayers.entities.push(entity);
|
||||
}
|
||||
|
||||
if (isSelected) {
|
||||
state.selectedEntityIdentifier = getEntityIdentifier(entity);
|
||||
}
|
||||
},
|
||||
prepare: (payload: { overrides?: Partial<CanvasRasterLayerState>; isSelected?: boolean }) => ({
|
||||
prepare: (payload: {
|
||||
overrides?: Partial<CanvasRasterLayerState>;
|
||||
isSelected?: boolean;
|
||||
/**
|
||||
* asdf
|
||||
*/
|
||||
deleteOthers?: boolean;
|
||||
}) => ({
|
||||
payload: { ...payload, id: getPrefixedId('raster_layer') },
|
||||
}),
|
||||
},
|
||||
@@ -603,9 +621,14 @@ export const canvasSlice = createSlice({
|
||||
inpaintMaskAdded: {
|
||||
reducer: (
|
||||
state,
|
||||
action: PayloadAction<{ id: string; overrides?: Partial<CanvasInpaintMaskState>; isSelected?: boolean }>
|
||||
action: PayloadAction<{
|
||||
id: string;
|
||||
overrides?: Partial<CanvasInpaintMaskState>;
|
||||
isSelected?: boolean;
|
||||
deleteOthers?: boolean;
|
||||
}>
|
||||
) => {
|
||||
const { id, overrides, isSelected } = action.payload;
|
||||
const { id, overrides, isSelected, deleteOthers } = action.payload;
|
||||
const entity: CanvasInpaintMaskState = {
|
||||
id,
|
||||
name: null,
|
||||
@@ -621,12 +644,22 @@ export const canvasSlice = createSlice({
|
||||
},
|
||||
};
|
||||
merge(entity, overrides);
|
||||
state.inpaintMasks.entities.push(entity);
|
||||
|
||||
if (deleteOthers) {
|
||||
state.inpaintMasks.entities = [entity];
|
||||
} else {
|
||||
state.inpaintMasks.entities.push(entity);
|
||||
}
|
||||
|
||||
if (isSelected) {
|
||||
state.selectedEntityIdentifier = getEntityIdentifier(entity);
|
||||
}
|
||||
},
|
||||
prepare: (payload?: { overrides?: Partial<CanvasInpaintMaskState>; isSelected?: boolean }) => ({
|
||||
prepare: (payload?: {
|
||||
overrides?: Partial<CanvasInpaintMaskState>;
|
||||
isSelected?: boolean;
|
||||
deleteOthers?: boolean;
|
||||
}) => ({
|
||||
payload: { ...payload, id: getPrefixedId('inpaint_mask') },
|
||||
}),
|
||||
},
|
||||
|
||||
@@ -685,8 +685,6 @@ export type StagingAreaImage = {
|
||||
offsetY: number;
|
||||
};
|
||||
|
||||
export type SessionMode = 'generate' | 'compose';
|
||||
|
||||
export type CanvasState = {
|
||||
_version: 3;
|
||||
selectedEntityIdentifier: CanvasEntityIdentifier | null;
|
||||
|
||||
@@ -1,55 +1,61 @@
|
||||
import { ButtonGroup, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { IconSwitch } from 'common/components/IconSwitch';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiEyeBold, PiPencilBold } from 'react-icons/pi';
|
||||
|
||||
export const ViewerToggleMenu = () => {
|
||||
const TooltipEdit = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('common.edit')}</Text>
|
||||
<Text fontWeight="normal">{t('common.editDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TooltipEdit.displayName = 'TooltipEdit';
|
||||
|
||||
const TooltipView = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('common.view')}</Text>
|
||||
<Text fontWeight="normal">{t('common.viewDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
TooltipView.displayName = 'TooltipView';
|
||||
|
||||
export const ViewerToggle = memo(() => {
|
||||
const imageViewer = useImageViewer();
|
||||
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
|
||||
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
|
||||
const onChange = useCallback(
|
||||
(isChecked: boolean) => {
|
||||
if (isChecked) {
|
||||
imageViewer.onClose();
|
||||
} else {
|
||||
imageViewer.onOpen();
|
||||
}
|
||||
},
|
||||
[imageViewer]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex gap={4} alignItems="center" justifyContent="center">
|
||||
<ButtonGroup size="md">
|
||||
<Tooltip
|
||||
hasArrow
|
||||
label={
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('common.viewing')}</Text>
|
||||
<Text fontWeight="normal">{t('common.viewingDesc')}</Text>
|
||||
</Flex>
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
icon={<PiEyeBold />}
|
||||
onClick={imageViewer.onOpen}
|
||||
variant={imageViewer.isOpen ? 'solid' : 'outline'}
|
||||
colorScheme={imageViewer.isOpen ? 'invokeBlue' : 'base'}
|
||||
aria-label={t('common.viewing')}
|
||||
w={12}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip
|
||||
hasArrow
|
||||
label={
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('common.editing')}</Text>
|
||||
<Text fontWeight="normal">{t('common.editingDesc')}</Text>
|
||||
</Flex>
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
icon={<PiPencilBold />}
|
||||
onClick={imageViewer.onClose}
|
||||
variant={!imageViewer.isOpen ? 'solid' : 'outline'}
|
||||
colorScheme={!imageViewer.isOpen ? 'invokeBlue' : 'base'}
|
||||
aria-label={t('common.editing')}
|
||||
w={12}
|
||||
/>
|
||||
</Tooltip>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<IconSwitch
|
||||
isChecked={!imageViewer.isOpen}
|
||||
onChange={onChange}
|
||||
iconUnchecked={<PiEyeBold />}
|
||||
tooltipUnchecked={<TooltipView />}
|
||||
iconChecked={<PiPencilBold />}
|
||||
tooltipChecked={<TooltipEdit />}
|
||||
ariaLabel="Toggle viewer"
|
||||
/>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ViewerToggle.displayName = 'ViewerToggle';
|
||||
|
||||
@@ -7,7 +7,7 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||
import { ViewerToggle } from './ViewerToggleMenu';
|
||||
|
||||
const selectShowToggle = createSelector(selectActiveTab, (tab) => {
|
||||
if (tab === 'upscaling' || tab === 'workflows') {
|
||||
@@ -31,7 +31,7 @@ export const ViewerToolbar = memo(() => {
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
{showToggle && <ViewerToggleMenu />}
|
||||
{showToggle && <ViewerToggle />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
@@ -179,12 +179,12 @@ const ModelList = () => {
|
||||
{/* T5 Encoders List */}
|
||||
{isLoadingT5EncoderModels && <FetchingModelsLoader loadingMessage="Loading T5 Encoder Models..." />}
|
||||
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
|
||||
<ModelListWrapper title="T5 Encoder" modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
|
||||
)}
|
||||
{/* Clip Embed List */}
|
||||
{isLoadingClipEmbedModels && <FetchingModelsLoader loadingMessage="Loading Clip Embed Models..." />}
|
||||
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
|
||||
<ModelListWrapper title="Clip Embed" modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
|
||||
)}
|
||||
{/* Spandrel Image to Image List */}
|
||||
{isLoadingSpandrelImageToImageModels && (
|
||||
@@ -192,7 +192,7 @@ const ModelList = () => {
|
||||
)}
|
||||
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Image-to-Image"
|
||||
title={t('modelManager.spandrelImageToImage')}
|
||||
modelList={filteredSpandrelImageToImageModels}
|
||||
key="spandrel-image-to-image"
|
||||
/>
|
||||
|
||||
@@ -19,11 +19,10 @@ export const ModelTypeFilter = memo(() => {
|
||||
controlnet: 'ControlNet',
|
||||
vae: 'VAE',
|
||||
t2i_adapter: t('common.t2iAdapter'),
|
||||
t5_encoder: 'T5Encoder',
|
||||
clip_embed: 'Clip Embed',
|
||||
t5_encoder: t('modelManager.t5Encoder'),
|
||||
clip_embed: t('modelManager.clipEmbed'),
|
||||
ip_adapter: t('common.ipAdapter'),
|
||||
clip_vision: 'Clip Vision',
|
||||
spandrel_image_to_image: 'Image-to-Image',
|
||||
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Flex } from '@invoke-ai/ui-library';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { AddNodeCmdk } from 'features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk';
|
||||
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
|
||||
import WorkflowEditorSettings from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { LoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
|
||||
import { SaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/SaveWorkflowAsDialog';
|
||||
import { memo } from 'react';
|
||||
@@ -39,6 +40,7 @@ const NodeEditor = () => {
|
||||
<LoadWorkflowFromGraphModal />
|
||||
</>
|
||||
)}
|
||||
<WorkflowEditorSettings />
|
||||
{isLoading && <IAINoContentFallback label={t('nodes.loadingNodes')} icon={MdDeviceHub} />}
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
@@ -33,6 +33,7 @@ import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memoize } from 'lodash-es';
|
||||
import { computed } from 'nanostores';
|
||||
import type { ChangeEvent } from 'react';
|
||||
@@ -162,13 +163,13 @@ const cmdkRootSx: SystemStyleObject = {
|
||||
export const AddNodeCmdk = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const addNodeCmdk = useAddNodeCmdk();
|
||||
const addNodeCmdkIsOpen = useStore(addNodeCmdk.$boolean);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const addNode = useAddNode();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const throttledSearchTerm = useThrottle(searchTerm, 100);
|
||||
|
||||
useHotkeys(['shift+a', 'space'], addNodeCmdk.setTrue, { preventDefault: true });
|
||||
useHotkeys(['shift+a', 'space'], addNodeCmdk.setTrue, { enabled: tab === 'workflows', preventDefault: true }, [tab]);
|
||||
|
||||
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTerm(e.target.value);
|
||||
@@ -190,7 +191,7 @@ export const AddNodeCmdk = memo(() => {
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isOpen={addNodeCmdkIsOpen}
|
||||
isOpen={addNodeCmdk.isTrue}
|
||||
onClose={onClose}
|
||||
useInert={false}
|
||||
initialFocusRef={inputRef}
|
||||
|
||||
@@ -6,6 +6,8 @@ import {
|
||||
isBoardFieldInputTemplate,
|
||||
isBooleanFieldInputInstance,
|
||||
isBooleanFieldInputTemplate,
|
||||
isCLIPEmbedModelFieldInputInstance,
|
||||
isCLIPEmbedModelFieldInputTemplate,
|
||||
isColorFieldInputInstance,
|
||||
isColorFieldInputTemplate,
|
||||
isControlNetModelFieldInputInstance,
|
||||
@@ -16,6 +18,8 @@ import {
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@@ -49,10 +53,12 @@ import { memo } from 'react';
|
||||
|
||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
|
||||
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
|
||||
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
|
||||
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
|
||||
@@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
|
||||
import type { ClipEmbedModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
|
||||
|
||||
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useClipEmbedModels();
|
||||
const _onChange = useCallback(
|
||||
(value: ClipEmbedModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldCLIPEmbedValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CLIPEmbedModelFieldInputComponent);
|
||||
@@ -0,0 +1,60 @@
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
|
||||
|
||||
const FluxVAEModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useFluxVAEModels();
|
||||
const _onChange = useCallback(
|
||||
(value: VAEModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldFluxVAEModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(FluxVAEModelFieldInputComponent);
|
||||
@@ -14,9 +14,9 @@ import {
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Switch,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import ReloadNodeTemplatesButton from 'features/nodes/components/flow/panels/TopRightPanel/ReloadSchemaButton';
|
||||
import {
|
||||
selectionModeChanged,
|
||||
@@ -32,20 +32,17 @@ import {
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
} from 'features/nodes/store/workflowSettingsSlice';
|
||||
import type { ChangeEvent, ReactNode } from 'react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SelectionMode } from 'reactflow';
|
||||
|
||||
const formLabelProps: FormLabelProps = { flexGrow: 1 };
|
||||
export const [useWorkflowEditorSettingsModal] = buildUseBoolean(false);
|
||||
|
||||
type Props = {
|
||||
children: (props: { onOpen: () => void }) => ReactNode;
|
||||
};
|
||||
|
||||
const WorkflowEditorSettings = ({ children }: Props) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const WorkflowEditorSettings = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
|
||||
const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid);
|
||||
const selectionMode = useAppSelector(selectSelectionMode);
|
||||
@@ -99,76 +96,72 @@ const WorkflowEditorSettings = ({ children }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<>
|
||||
{children({ onOpen })}
|
||||
|
||||
<Modal isOpen={isOpen} onClose={onClose} size="2xl" isCentered useInert={false}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t('nodes.workflowSettings')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Flex flexDirection="column" gap={4} py={4}>
|
||||
<Heading size="sm">{t('parameters.general')}</Heading>
|
||||
<FormControlGroup orientation="vertical" formLabelProps={formLabelProps}>
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.animatedEdges')}</FormLabel>
|
||||
<Switch onChange={handleChangeShouldAnimate} isChecked={shouldAnimateEdges} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.animatedEdgesHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.snapToGrid')}</FormLabel>
|
||||
<Switch isChecked={shouldSnapToGrid} onChange={handleChangeShouldSnap} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.snapToGridHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.colorCodeEdges')}</FormLabel>
|
||||
<Switch isChecked={shouldColorEdges} onChange={handleChangeShouldColor} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.colorCodeEdgesHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.fullyContainNodes')}</FormLabel>
|
||||
<Switch isChecked={selectionMode === SelectionMode.Full} onChange={handleChangeSelectionMode} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
|
||||
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<Heading size="sm" pt={4}>
|
||||
{t('common.advanced')}
|
||||
</Heading>
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.validateConnections')}</FormLabel>
|
||||
<Switch isChecked={shouldValidateGraph} onChange={handleChangeShouldValidate} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.validateConnectionsHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
</FormControlGroup>
|
||||
<ReloadNodeTemplatesButton />
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
<Modal isOpen={modal.isTrue} onClose={modal.setFalse} size="2xl" isCentered useInert={false}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t('nodes.workflowSettings')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Flex flexDirection="column" gap={4} py={4}>
|
||||
<Heading size="sm">{t('parameters.general')}</Heading>
|
||||
<FormControlGroup orientation="vertical" formLabelProps={formLabelProps}>
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.animatedEdges')}</FormLabel>
|
||||
<Switch onChange={handleChangeShouldAnimate} isChecked={shouldAnimateEdges} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.animatedEdgesHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.snapToGrid')}</FormLabel>
|
||||
<Switch isChecked={shouldSnapToGrid} onChange={handleChangeShouldSnap} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.snapToGridHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.colorCodeEdges')}</FormLabel>
|
||||
<Switch isChecked={shouldColorEdges} onChange={handleChangeShouldColor} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.colorCodeEdgesHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.fullyContainNodes')}</FormLabel>
|
||||
<Switch isChecked={selectionMode === SelectionMode.Full} onChange={handleChangeSelectionMode} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
|
||||
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
<Heading size="sm" pt={4}>
|
||||
{t('common.advanced')}
|
||||
</Heading>
|
||||
<FormControl>
|
||||
<Flex w="full">
|
||||
<FormLabel>{t('nodes.validateConnections')}</FormLabel>
|
||||
<Switch isChecked={shouldValidateGraph} onChange={handleChangeShouldValidate} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('nodes.validateConnectionsHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
<Divider />
|
||||
</FormControlGroup>
|
||||
<ReloadNodeTemplatesButton />
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import 'reactflow/dist/style.css';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import QueueControls from 'features/queue/components/QueueControls';
|
||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
|
||||
import WorkflowLibraryButton from 'features/workflowLibrary/components/WorkflowLibraryButton';
|
||||
@@ -34,7 +33,6 @@ const NodeEditorPanelGroup = () => {
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" gap={2} flexDir="column">
|
||||
<QueueControls />
|
||||
<Flex w="full" justifyContent="space-between" alignItems="center" gap="4" padding={1}>
|
||||
<Flex justifyContent="space-between" alignItems="center" gap="4">
|
||||
<WorkflowLibraryButton />
|
||||
|
||||
@@ -2,12 +2,12 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import {
|
||||
$addNodeCmdk,
|
||||
$didUpdateEdge,
|
||||
$edgePendingUpdate,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
edgesChanged,
|
||||
useAddNodeCmdk,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
@@ -21,6 +21,7 @@ export const useConnection = () => {
|
||||
const store = useAppStore();
|
||||
const templates = useStore($templates);
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const addNodeCmdk = useAddNodeCmdk();
|
||||
|
||||
const onConnectStart = useCallback<OnConnectStart>(
|
||||
(event, { nodeId, handleId, handleType }) => {
|
||||
@@ -107,9 +108,9 @@ export const useConnection = () => {
|
||||
$pendingConnection.set(null);
|
||||
} else {
|
||||
// The mouse is not over a node - we should open the add node popover
|
||||
$addNodeCmdk.set(true);
|
||||
addNodeCmdk.setTrue();
|
||||
}
|
||||
}, [store, templates, updateNodeInternals]);
|
||||
}, [addNodeCmdk, store, templates, updateNodeInternals]);
|
||||
|
||||
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||
return api;
|
||||
|
||||
@@ -7,11 +7,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
BooleanFieldValue,
|
||||
CLIPEmbedModelFieldValue,
|
||||
ColorFieldValue,
|
||||
ControlNetModelFieldValue,
|
||||
EnumFieldValue,
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
@@ -30,10 +32,12 @@ import type {
|
||||
import {
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldValue,
|
||||
zIntegerFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
@@ -347,6 +351,12 @@ export const nodesSlice = createSlice({
|
||||
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
|
||||
},
|
||||
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
|
||||
},
|
||||
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
|
||||
},
|
||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||
fieldValueReducer(state, action, zEnumFieldValue);
|
||||
},
|
||||
@@ -409,6 +419,8 @@ export const {
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
@@ -432,8 +444,7 @@ export const $didUpdateEdge = atom(false);
|
||||
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
|
||||
|
||||
export const $viewport = atom<Viewport>({ x: 0, y: 0, zoom: 1 });
|
||||
export const $addNodeCmdk = atom(false);
|
||||
export const useAddNodeCmdk = buildUseBoolean($addNodeCmdk);
|
||||
export const [useAddNodeCmdk, $addNodeCmdk] = buildUseBoolean(false);
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateNodesState = (state: any): any => {
|
||||
@@ -514,6 +525,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
nodesChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
|
||||
@@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T5EncoderModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('CLIPEmbedModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FluxVAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
|
||||
zT2IAdapterModelFieldType,
|
||||
zSpandrelImageToImageModelFieldType,
|
||||
zT5EncoderModelFieldType,
|
||||
zCLIPEmbedModelFieldType,
|
||||
zFluxVAEModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
@@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
|
||||
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
|
||||
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregio
|
||||
// #endregion
|
||||
|
||||
// #region FluxVAEModelField
|
||||
|
||||
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFluxVAEModelFieldValue,
|
||||
});
|
||||
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zFluxVAEModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFluxVAEModelFieldValue,
|
||||
});
|
||||
|
||||
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
|
||||
|
||||
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
|
||||
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
|
||||
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
|
||||
zFluxVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
|
||||
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region CLIPEmbedModelField
|
||||
|
||||
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
|
||||
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zCLIPEmbedModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zCLIPEmbedModelFieldValue,
|
||||
});
|
||||
|
||||
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
|
||||
|
||||
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
|
||||
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
|
||||
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
|
||||
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
|
||||
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
@@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
|
||||
zT2IAdapterModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
@@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zSpandrelImageToImageModelFieldInputInstance,
|
||||
zT5EncoderModelFieldInputInstance,
|
||||
zFluxVAEModelFieldInputInstance,
|
||||
zCLIPEmbedModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
@@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zSpandrelImageToImageModelFieldInputTemplate,
|
||||
zT5EncoderModelFieldInputTemplate,
|
||||
zFluxVAEModelFieldInputTemplate,
|
||||
zCLIPEmbedModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
|
||||
@@ -10,7 +10,9 @@ export const prepareLinearUIBatch = (
|
||||
g: Graph,
|
||||
prepend: boolean,
|
||||
noise: Invocation<'noise'>,
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt'>
|
||||
posCond: Invocation<'compel' | 'sdxl_compel_prompt'>,
|
||||
origin: 'generation' | 'workflows' | 'upscaling',
|
||||
destination: 'canvas' | 'gallery'
|
||||
): BatchConfig => {
|
||||
const { iterations, model, shouldRandomizeSeed, seed, shouldConcatPrompts } = state.params;
|
||||
const { prompts, seedBehaviour } = state.dynamicPrompts;
|
||||
@@ -103,7 +105,8 @@ export const prepareLinearUIBatch = (
|
||||
graph: g.getGraph(),
|
||||
runs: 1,
|
||||
data,
|
||||
origin: 'canvas',
|
||||
origin,
|
||||
destination,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ export const addInpaint = async (
|
||||
const canvas = selectCanvasSlice(state);
|
||||
|
||||
const { bbox } = canvas;
|
||||
const { mode } = canvasSession;
|
||||
const { sendToCanvas: isComposing } = canvasSession;
|
||||
|
||||
const initialImage = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);
|
||||
const maskImage = await manager.compositor.getCompositeInpaintMaskImageDTO(bbox.rect);
|
||||
@@ -99,7 +99,7 @@ export const addInpaint = async (
|
||||
g.addEdge(resizeImageToOriginalSize, 'image', canvasPasteBack, 'generated_image');
|
||||
g.addEdge(resizeMaskToOriginalSize, 'image', canvasPasteBack, 'mask');
|
||||
|
||||
if (mode === 'generate') {
|
||||
if (!isComposing) {
|
||||
canvasPasteBack.source_image = { image_name: initialImage.image_name };
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ export const addInpaint = async (
|
||||
|
||||
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');
|
||||
|
||||
if (mode === 'generate') {
|
||||
if (!isComposing) {
|
||||
canvasPasteBack.source_image = { image_name: initialImage.image_name };
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ export const addOutpaint = async (
|
||||
const canvas = selectCanvasSlice(state);
|
||||
|
||||
const { bbox } = canvas;
|
||||
const { mode } = canvasSession;
|
||||
const { sendToCanvas: isComposing } = canvasSession;
|
||||
|
||||
const initialImage = await manager.compositor.getCompositeRasterLayerImageDTO(bbox.rect);
|
||||
const maskImage = await manager.compositor.getCompositeInpaintMaskImageDTO(bbox.rect);
|
||||
@@ -123,7 +123,7 @@ export const addOutpaint = async (
|
||||
g.addEdge(resizeOutputImageToOriginalSize, 'image', canvasPasteBack, 'generated_image');
|
||||
g.addEdge(resizeOutputMaskToOriginalSize, 'image', canvasPasteBack, 'mask');
|
||||
|
||||
if (mode === 'generate') {
|
||||
if (!isComposing) {
|
||||
canvasPasteBack.source_image = { image_name: initialImage.image_name };
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ export const addOutpaint = async (
|
||||
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
|
||||
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');
|
||||
|
||||
if (mode === 'generate') {
|
||||
if (!isComposing) {
|
||||
canvasPasteBack.source_image = { image_name: initialImage.image_name };
|
||||
}
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ export const buildSD1Graph = async (
|
||||
canvasOutput = addWatermarker(g, canvasOutput);
|
||||
}
|
||||
|
||||
const shouldSaveToGallery = canvasSession.mode === 'generate' || canvasSettings.autoSave;
|
||||
const shouldSaveToGallery = !canvasSession.sendToCanvas || canvasSettings.autoSave;
|
||||
|
||||
g.updateNode(canvasOutput, {
|
||||
id: getPrefixedId('canvas_output'),
|
||||
|
||||
@@ -285,7 +285,7 @@ export const buildSDXLGraph = async (
|
||||
canvasOutput = addWatermarker(g, canvasOutput);
|
||||
}
|
||||
|
||||
const shouldSaveToGallery = canvasSession.mode === 'generate' || canvasSettings.autoSave;
|
||||
const shouldSaveToGallery = !canvasSession.sendToCanvas || canvasSettings.autoSave;
|
||||
|
||||
g.updateNode(canvasOutput, {
|
||||
id: getPrefixedId('canvas_output'),
|
||||
|
||||
@@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
T5EncoderModelField: undefined,
|
||||
FluxVAEModelField: undefined,
|
||||
CLIPEmbedModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { FieldParseError } from 'features/nodes/types/error';
|
||||
import type {
|
||||
BoardFieldInputTemplate,
|
||||
BooleanFieldInputTemplate,
|
||||
CLIPEmbedModelFieldInputTemplate,
|
||||
ColorFieldInputTemplate,
|
||||
ControlNetModelFieldInputTemplate,
|
||||
EnumFieldInputTemplate,
|
||||
@@ -9,6 +10,7 @@ import type {
|
||||
FieldType,
|
||||
FloatFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
@@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: CLIPEmbedModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FluxVAEModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
|
||||
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useClearQueueConfirmationAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleFill } from 'react-icons/pi';
|
||||
|
||||
import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
|
||||
|
||||
type Props = ButtonProps;
|
||||
|
||||
const ClearQueueButton = (props: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dialogState = useClearQueueConfirmationAlertDialog();
|
||||
const { isLoading, isDisabled } = useClearQueue();
|
||||
const clearQueue = useClearQueue();
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
isDisabled={clearQueue.isDisabled}
|
||||
isLoading={clearQueue.isLoading}
|
||||
tooltip={t('queue.clearTooltip')}
|
||||
leftIcon={<PiTrashSimpleFill />}
|
||||
colorScheme="error"
|
||||
onClick={dialogState.setTrue}
|
||||
onClick={clearQueue.openDialog}
|
||||
data-testid={t('queue.clear')}
|
||||
{...props}
|
||||
>
|
||||
|
||||
@@ -1,26 +1,72 @@
|
||||
import { ConfirmationAlertDialog, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo } from 'react';
|
||||
import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
const $boolean = atom(false);
|
||||
export const useClearQueueConfirmationAlertDialog = buildUseBoolean($boolean);
|
||||
const [useClearQueueConfirmationAlertDialog] = buildUseBoolean(false);
|
||||
|
||||
export const useClearQueue = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const dialog = useClearQueueConfirmationAlertDialog();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const isConnected = useStore($isConnected);
|
||||
const [trigger, { isLoading }] = useClearQueueMutation({
|
||||
fixedCacheKey: 'clearQueue',
|
||||
});
|
||||
|
||||
const clearQueue = useCallback(async () => {
|
||||
if (!queueStatus?.queue.total) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await trigger().unwrap();
|
||||
toast({
|
||||
id: 'QUEUE_CLEAR_SUCCEEDED',
|
||||
title: t('queue.clearSucceeded'),
|
||||
status: 'success',
|
||||
});
|
||||
dispatch(listCursorChanged(undefined));
|
||||
dispatch(listPriorityChanged(undefined));
|
||||
} catch {
|
||||
toast({
|
||||
id: 'QUEUE_CLEAR_FAILED',
|
||||
title: t('queue.clearFailed'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
}, [queueStatus?.queue.total, trigger, dispatch, t]);
|
||||
|
||||
const isDisabled = useMemo(() => !isConnected || !queueStatus?.queue.total, [isConnected, queueStatus?.queue.total]);
|
||||
|
||||
return {
|
||||
clearQueue,
|
||||
isOpen: dialog.isTrue,
|
||||
openDialog: dialog.setTrue,
|
||||
closeDialog: dialog.setFalse,
|
||||
isLoading,
|
||||
queueStatus,
|
||||
isDisabled,
|
||||
};
|
||||
};
|
||||
|
||||
export const ClearQueueConfirmationsAlertDialog = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dialogState = useClearQueueConfirmationAlertDialog();
|
||||
const isOpen = useStore(dialogState.$boolean);
|
||||
const { clearQueue } = useClearQueue();
|
||||
const clearQueue = useClearQueue();
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={isOpen}
|
||||
onClose={dialogState.setFalse}
|
||||
isOpen={clearQueue.isOpen}
|
||||
onClose={clearQueue.closeDialog}
|
||||
title={t('queue.clearTooltip')}
|
||||
acceptCallback={clearQueue}
|
||||
acceptCallback={clearQueue.clearQueue}
|
||||
acceptButtonText={t('queue.clear')}
|
||||
useInert={false}
|
||||
>
|
||||
|
||||
@@ -1,67 +1,40 @@
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { useClearQueueConfirmationAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { QueueCountBadge } from 'features/queue/components/QueueCountBadge';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { memo } from 'react';
|
||||
import { memo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
type ClearQueueButtonProps = Omit<IconButtonProps, 'aria-label'>;
|
||||
import { useClearQueue } from './ClearQueueConfirmationAlertDialog';
|
||||
|
||||
export const ClearAllQueueIconButton = memo((props: ClearQueueButtonProps) => {
|
||||
export const ClearQueueIconButton = memo((_) => {
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
const { t } = useTranslation();
|
||||
const dialogState = useClearQueueConfirmationAlertDialog();
|
||||
const { isLoading, isDisabled } = useClearQueue();
|
||||
const clearQueue = useClearQueue();
|
||||
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
aria-label={t('queue.clear')}
|
||||
tooltip={t('queue.clearTooltip')}
|
||||
icon={<PiTrashSimpleBold size="16px" />}
|
||||
colorScheme="error"
|
||||
onClick={dialogState.setTrue}
|
||||
data-testid={t('queue.clear')}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
ClearAllQueueIconButton.displayName = 'ClearAllQueueIconButton';
|
||||
|
||||
const ClearSingleQueueItemIconButton = memo((props: ClearQueueButtonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { cancelQueueItem, isLoading, isDisabled } = useCancelCurrentQueueItem();
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
isDisabled={isDisabled}
|
||||
isLoading={isLoading}
|
||||
aria-label={t('queue.cancel')}
|
||||
tooltip={t('queue.cancelTooltip')}
|
||||
icon={<PiXBold size="16px" />}
|
||||
colorScheme="error"
|
||||
onClick={cancelQueueItem}
|
||||
data-testid={t('queue.cancel')}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
ClearSingleQueueItemIconButton.displayName = 'ClearSingleQueueItemIconButton';
|
||||
|
||||
export const ClearQueueIconButton = memo((props: ClearQueueButtonProps) => {
|
||||
// Show the single item clear button when shift is pressed
|
||||
// Otherwise show the clear queue button
|
||||
const shift = useShiftModifier();
|
||||
|
||||
if (shift) {
|
||||
return <ClearAllQueueIconButton {...props} />;
|
||||
}
|
||||
|
||||
return <ClearSingleQueueItemIconButton {...props} />;
|
||||
return (
|
||||
<>
|
||||
<IconButton
|
||||
ref={ref}
|
||||
size="lg"
|
||||
isDisabled={shift ? clearQueue.isDisabled : cancelCurrentQueueItem.isDisabled}
|
||||
isLoading={shift ? clearQueue.isLoading : cancelCurrentQueueItem.isLoading}
|
||||
aria-label={shift ? t('queue.clear') : t('queue.cancel')}
|
||||
tooltip={shift ? t('queue.clearTooltip') : t('queue.cancelTooltip')}
|
||||
icon={shift ? <PiTrashSimpleBold /> : <PiXBold />}
|
||||
colorScheme="error"
|
||||
onClick={shift ? clearQueue.openDialog : cancelCurrentQueueItem.cancelQueueItem}
|
||||
data-testid={shift ? t('queue.clear') : t('queue.cancel')}
|
||||
/>
|
||||
{/* The badge is dynamically positioned, needs a ref to the target element */}
|
||||
<QueueCountBadge targetRef={ref} />
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
ClearQueueIconButton.displayName = 'ClearQueueIconButton';
|
||||
|
||||
@@ -15,7 +15,7 @@ export const InvokeQueueBackButton = memo(() => {
|
||||
const isLoadingDynamicPrompts = useAppSelector(selectDynamicPromptsIsLoading);
|
||||
|
||||
return (
|
||||
<Flex pos="relative" flexGrow={1} minW="240px">
|
||||
<Flex pos="relative" w="192px">
|
||||
<QueueIterationsNumberInput />
|
||||
<QueueButtonTooltip>
|
||||
<Button
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
import {
|
||||
Badge,
|
||||
Box,
|
||||
IconButton,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuDivider,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Portal,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { Coordinate } from 'features/controlLayers/store/types';
|
||||
import { useClearQueueConfirmationAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { usePauseProcessor } from 'features/queue/hooks/usePauseProcessor';
|
||||
import { useResumeProcessor } from 'features/queue/hooks/useResumeProcessor';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPauseFill, PiPlayFill, PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { RiListCheck, RiPlayList2Fill } from 'react-icons/ri';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
export const QueueActionsMenuButton = memo(() => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [badgePos, setBadgePos] = useState<Coordinate | null>(null);
|
||||
const menuButtonRef = useRef<HTMLButtonElement>(null);
|
||||
const dialogState = useClearQueueConfirmationAlertDialog();
|
||||
const isPauseEnabled = useFeatureStatus('pauseQueue');
|
||||
const isResumeEnabled = useFeatureStatus('resumeQueue');
|
||||
const { queueSize } = useGetQueueStatusQuery(undefined, {
|
||||
selectFromResult: (res) => ({
|
||||
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
|
||||
}),
|
||||
});
|
||||
const { isLoading: isLoadingClearQueue, isDisabled: isDisabledClearQueue } = useClearQueue();
|
||||
const {
|
||||
resumeProcessor,
|
||||
isLoading: isLoadingResumeProcessor,
|
||||
isDisabled: isDisabledResumeProcessor,
|
||||
} = useResumeProcessor();
|
||||
const {
|
||||
pauseProcessor,
|
||||
isLoading: isLoadingPauseProcessor,
|
||||
isDisabled: isDisabledPauseProcessor,
|
||||
} = usePauseProcessor();
|
||||
const openQueue = useCallback(() => {
|
||||
dispatch(setActiveTab('queue'));
|
||||
}, [dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
if (menuButtonRef.current) {
|
||||
const { x, y } = menuButtonRef.current.getBoundingClientRect();
|
||||
setBadgePos({ x: x - 10, y: y - 10 });
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Box pos="relative">
|
||||
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose} placement="bottom-end">
|
||||
<MenuButton ref={menuButtonRef} as={IconButton} aria-label="Queue Actions Menu" icon={<RiListCheck />} />
|
||||
<MenuList>
|
||||
<MenuItem
|
||||
isDestructive
|
||||
icon={<PiTrashSimpleBold size="16px" />}
|
||||
onClick={dialogState.setTrue}
|
||||
isLoading={isLoadingClearQueue}
|
||||
isDisabled={isDisabledClearQueue}
|
||||
>
|
||||
{t('queue.clearTooltip')}
|
||||
</MenuItem>
|
||||
{isResumeEnabled && (
|
||||
<MenuItem
|
||||
icon={<PiPlayFill size="14px" />}
|
||||
onClick={resumeProcessor}
|
||||
isLoading={isLoadingResumeProcessor}
|
||||
isDisabled={isDisabledResumeProcessor}
|
||||
>
|
||||
{t('queue.resumeTooltip')}
|
||||
</MenuItem>
|
||||
)}
|
||||
{isPauseEnabled && (
|
||||
<MenuItem
|
||||
icon={<PiPauseFill size="14px" />}
|
||||
onClick={pauseProcessor}
|
||||
isLoading={isLoadingPauseProcessor}
|
||||
isDisabled={isDisabledPauseProcessor}
|
||||
>
|
||||
{t('queue.pauseTooltip')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuDivider />
|
||||
<MenuItem icon={<RiPlayList2Fill />} onClick={openQueue}>
|
||||
{t('queue.openQueue')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
{queueSize > 0 && badgePos !== null && (
|
||||
<Portal>
|
||||
<Badge
|
||||
pos="absolute"
|
||||
insetInlineStart={badgePos.x}
|
||||
insetBlockStart={badgePos.y}
|
||||
colorScheme="invokeYellow"
|
||||
zIndex="docked"
|
||||
>
|
||||
{queueSize}
|
||||
</Badge>
|
||||
</Portal>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
QueueActionsMenuButton.displayName = 'QueueActionsMenuButton';
|
||||
@@ -1,24 +1,27 @@
|
||||
import { ButtonGroup, Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasSendToToggle } from 'features/controlLayers/components/CanvasSendToToggle';
|
||||
import { ClearQueueIconButton } from 'features/queue/components/ClearQueueIconButton';
|
||||
import QueueFrontButton from 'features/queue/components/QueueFrontButton';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { InvokeQueueBackButton } from './InvokeQueueBackButton';
|
||||
import { QueueActionsMenuButton } from './QueueActionsMenuButton';
|
||||
|
||||
const QueueControls = () => {
|
||||
const isPrependEnabled = useFeatureStatus('prependQueue');
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
return (
|
||||
<Flex w="full" position="relative" borderRadius="base" gap={2} flexDir="column">
|
||||
<ButtonGroup size="lg" isAttached={false}>
|
||||
<Flex gap={2}>
|
||||
{isPrependEnabled && <QueueFrontButton />}
|
||||
<InvokeQueueBackButton />
|
||||
<Spacer />
|
||||
<QueueActionsMenuButton />
|
||||
{tab === 'generation' && <CanvasSendToToggle />}
|
||||
<ClearQueueIconButton />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<ProgressBar />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import { Badge, Portal } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $isParametersPanelOpen } from 'features/ui/store/uiSlice';
|
||||
import type { RefObject } from 'react';
|
||||
import { memo, useEffect, useState } from 'react';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
type Props = {
|
||||
targetRef: RefObject<HTMLDivElement>;
|
||||
};
|
||||
|
||||
export const QueueCountBadge = memo(({ targetRef }: Props) => {
|
||||
const [badgePos, setBadgePos] = useState<{ x: string; y: string } | null>(null);
|
||||
const isParametersPanelOpen = useStore($isParametersPanelOpen);
|
||||
const { queueSize } = useGetQueueStatusQuery(undefined, {
|
||||
selectFromResult: (res) => ({
|
||||
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
|
||||
}),
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!targetRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const target = targetRef.current;
|
||||
const parent = target.parentElement;
|
||||
|
||||
if (!parent) {
|
||||
return;
|
||||
}
|
||||
|
||||
const cb = () => {
|
||||
if (!$isParametersPanelOpen.get()) {
|
||||
return;
|
||||
}
|
||||
const { x, y } = target.getBoundingClientRect();
|
||||
setBadgePos({ x: `${x - 7}px`, y: `${y - 5}px` });
|
||||
};
|
||||
|
||||
const resizeObserver = new ResizeObserver(cb);
|
||||
resizeObserver.observe(parent);
|
||||
cb();
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [targetRef]);
|
||||
|
||||
if (queueSize === 0) {
|
||||
return null;
|
||||
}
|
||||
if (!badgePos) {
|
||||
return null;
|
||||
}
|
||||
if (!isParametersPanelOpen) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Portal>
|
||||
<Badge
|
||||
pos="absolute"
|
||||
insetInlineStart={badgePos.x}
|
||||
insetBlockStart={badgePos.y}
|
||||
colorScheme="invokeYellow"
|
||||
zIndex="docked"
|
||||
shadow="dark-lg"
|
||||
userSelect="none"
|
||||
>
|
||||
{queueSize}
|
||||
</Badge>
|
||||
</Portal>
|
||||
);
|
||||
});
|
||||
|
||||
QueueCountBadge.displayName = 'QueueCountBadge';
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { ChakraProps, CollapseProps } from '@invoke-ai/ui-library';
|
||||
import { ButtonGroup, Collapse, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import QueueStatusBadge from 'features/queue/components/common/QueueStatusBadge';
|
||||
import { useDestinationText } from 'features/queue/components/QueueList/useDestinationText';
|
||||
import { useOriginText } from 'features/queue/components/QueueList/useOriginText';
|
||||
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
|
||||
import { getSecondsFromTimestamps } from 'features/queue/util/getSecondsFromTimestamps';
|
||||
@@ -52,6 +53,7 @@ const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
|
||||
|
||||
const isCanceled = useMemo(() => ['canceled', 'completed', 'failed'].includes(item.status), [item.status]);
|
||||
const originText = useOriginText(item.origin);
|
||||
const destinationText = useDestinationText(item.destination);
|
||||
|
||||
const icon = useMemo(() => <PiXBold />, []);
|
||||
return (
|
||||
@@ -76,6 +78,11 @@ const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
|
||||
{originText}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.destination} flexShrink={0}>
|
||||
<Text overflow="hidden" textOverflow="ellipsis" whiteSpace="nowrap" alignItems="center">
|
||||
{destinationText}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.time} alignItems="center" flexShrink={0}>
|
||||
{executionTime || '-'}
|
||||
</Flex>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Button, ButtonGroup, Flex, Heading, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useDestinationText } from 'features/queue/components/QueueList/useDestinationText';
|
||||
import { useOriginText } from 'features/queue/components/QueueList/useOriginText';
|
||||
import { useCancelBatch } from 'features/queue/hooks/useCancelBatch';
|
||||
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
|
||||
@@ -17,7 +18,7 @@ type Props = {
|
||||
};
|
||||
|
||||
const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
||||
const { session_id, batch_id, item_id, origin } = queueItemDTO;
|
||||
const { session_id, batch_id, item_id, origin, destination } = queueItemDTO;
|
||||
const { t } = useTranslation();
|
||||
const { cancelBatch, isLoading: isLoadingCancelBatch, isCanceled } = useCancelBatch(batch_id);
|
||||
|
||||
@@ -26,6 +27,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
||||
const { data: queueItem } = useGetQueueItemQuery(item_id);
|
||||
|
||||
const originText = useOriginText(origin);
|
||||
const destinationText = useDestinationText(destination);
|
||||
|
||||
const statusAndTiming = useMemo(() => {
|
||||
if (!queueItem) {
|
||||
@@ -54,6 +56,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
||||
>
|
||||
<QueueItemData label={t('queue.status')} data={statusAndTiming} />
|
||||
<QueueItemData label={t('queue.origin')} data={originText} />
|
||||
<QueueItemData label={t('queue.destination')} data={destinationText} />
|
||||
<QueueItemData label={t('queue.item')} data={item_id} />
|
||||
<QueueItemData label={t('queue.batch')} data={batch_id} />
|
||||
<QueueItemData label={t('queue.session')} data={session_id} />
|
||||
|
||||
@@ -25,6 +25,9 @@ const QueueListHeader = () => {
|
||||
<Flex ps={0.5} w={COLUMN_WIDTHS.origin} alignItems="center">
|
||||
<Text variant="subtext">{t('queue.origin')}</Text>
|
||||
</Flex>
|
||||
<Flex ps={0.5} w={COLUMN_WIDTHS.destination} alignItems="center">
|
||||
<Text variant="subtext">{t('queue.destination')}</Text>
|
||||
</Flex>
|
||||
<Flex ps={0.5} w={COLUMN_WIDTHS.time} alignItems="center">
|
||||
<Text variant="subtext">{t('queue.time')}</Text>
|
||||
</Flex>
|
||||
|
||||
@@ -4,7 +4,8 @@ export const COLUMN_WIDTHS = {
|
||||
statusDot: 2,
|
||||
time: '4rem',
|
||||
origin: '5rem',
|
||||
destination: '6rem',
|
||||
batchId: '5rem',
|
||||
fieldValues: 'auto',
|
||||
actions: 'auto',
|
||||
};
|
||||
} as const;
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { SessionQueueItemDTO } from 'services/api/types';
|
||||
|
||||
export const useDestinationText = (destination: SessionQueueItemDTO['destination']) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (destination === 'canvas') {
|
||||
return t('queue.canvas');
|
||||
}
|
||||
|
||||
if (destination === 'gallery') {
|
||||
return t('queue.gallery');
|
||||
}
|
||||
|
||||
return t('queue.other');
|
||||
};
|
||||
@@ -4,13 +4,17 @@ import type { SessionQueueItemDTO } from 'services/api/types';
|
||||
export const useOriginText = (origin: SessionQueueItemDTO['origin']) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (origin === 'canvas') {
|
||||
return t('queue.originCanvas');
|
||||
if (origin === 'generation') {
|
||||
return t('queue.generation');
|
||||
}
|
||||
|
||||
if (origin === 'workflows') {
|
||||
return t('queue.originWorkflows');
|
||||
return t('queue.workflows');
|
||||
}
|
||||
|
||||
return t('queue.originOther');
|
||||
if (origin === 'upscaling') {
|
||||
return t('queue.upscaling');
|
||||
}
|
||||
|
||||
return t('queue.other');
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user