mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-18 10:37:55 -05:00
Compare commits
678 Commits
ryan/flux-
...
v5.0.0.dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8dadc8c0a | ||
|
|
d421eaeff0 | ||
|
|
d434e80e99 | ||
|
|
918ed5c264 | ||
|
|
2f98c876df | ||
|
|
2b9f9f3d68 | ||
|
|
f326bfe1b0 | ||
|
|
c57966ec38 | ||
|
|
3c5bf48810 | ||
|
|
c0db830398 | ||
|
|
3e3e00b24d | ||
|
|
640af85c93 | ||
|
|
81060eef71 | ||
|
|
ea193903ee | ||
|
|
5a251a4876 | ||
|
|
ce93fdd076 | ||
|
|
788e03fbc3 | ||
|
|
8cb1584379 | ||
|
|
63a2b6cdfe | ||
|
|
b193bc7b14 | ||
|
|
dce488c691 | ||
|
|
615b09d453 | ||
|
|
1dd521a792 | ||
|
|
f6a703c367 | ||
|
|
5a69b62a1a | ||
|
|
c4b9ba8e12 | ||
|
|
c04790158f | ||
|
|
a0ed7baf73 | ||
|
|
f96148ca52 | ||
|
|
96786ed62b | ||
|
|
cc9cc62707 | ||
|
|
07cf5807be | ||
|
|
013414e962 | ||
|
|
10bba49e28 | ||
|
|
7888b4913a | ||
|
|
27b00cedc5 | ||
|
|
0bfdd5acb8 | ||
|
|
6e6629ad74 | ||
|
|
fe257e3e46 | ||
|
|
0b1dc365d0 | ||
|
|
c75a933b17 | ||
|
|
f1525c2a82 | ||
|
|
8266c10c53 | ||
|
|
34554f8555 | ||
|
|
200dbbbf55 | ||
|
|
5a30ff6045 | ||
|
|
73e8837d45 | ||
|
|
6ba78e7632 | ||
|
|
736fabe327 | ||
|
|
ad53169eba | ||
|
|
0a22e6b5a1 | ||
|
|
5eb9e2ba04 | ||
|
|
376a626419 | ||
|
|
5d089ecc77 | ||
|
|
3ed85d821b | ||
|
|
7c29caa245 | ||
|
|
6161c71aa7 | ||
|
|
f7e68ac6d6 | ||
|
|
5074c97683 | ||
|
|
bb6b5c4941 | ||
|
|
ff12f16343 | ||
|
|
397038e6e8 | ||
|
|
fcacc0f70a | ||
|
|
3ed4cf63ae | ||
|
|
4b753941e4 | ||
|
|
dbff929be5 | ||
|
|
44be058ac7 | ||
|
|
437129cf70 | ||
|
|
2e4a48b71c | ||
|
|
618473db02 | ||
|
|
4f4eee76b9 | ||
|
|
1628eb5900 | ||
|
|
92c754df79 | ||
|
|
921ade4b0a | ||
|
|
c38b05c002 | ||
|
|
277708d69e | ||
|
|
02061faaf0 | ||
|
|
a4fb8eb171 | ||
|
|
d3c2432bfb | ||
|
|
0c3e2c45f1 | ||
|
|
d019dbdbd3 | ||
|
|
2538b348c7 | ||
|
|
420178c09f | ||
|
|
75ecc56ca5 | ||
|
|
874b96c67d | ||
|
|
b0eddd19d4 | ||
|
|
cb9d0bce56 | ||
|
|
fd6d080fb0 | ||
|
|
8f6e6960e0 | ||
|
|
58064a1ea5 | ||
|
|
0d94a89d98 | ||
|
|
eb4d447e2f | ||
|
|
487422a53f | ||
|
|
b24f8e2b26 | ||
|
|
fd22ff7c9f | ||
|
|
9cd5107955 | ||
|
|
fd7f0056e3 | ||
|
|
b1673b9222 | ||
|
|
f4e8799561 | ||
|
|
6f85743996 | ||
|
|
243e7c954e | ||
|
|
33d25c7977 | ||
|
|
d11251f380 | ||
|
|
d5e5cae398 | ||
|
|
fa80452abf | ||
|
|
c9e396bee9 | ||
|
|
fd6de78746 | ||
|
|
6096c8ed0a | ||
|
|
5126ec114a | ||
|
|
221c51b498 | ||
|
|
28db39a932 | ||
|
|
5512f80066 | ||
|
|
70e27d6fbb | ||
|
|
e91149111e | ||
|
|
ad7919b3ec | ||
|
|
908f4117fc | ||
|
|
87b294708a | ||
|
|
a0be7c47cd | ||
|
|
87bc3f5f0a | ||
|
|
14dea8bad3 | ||
|
|
9ced9efc64 | ||
|
|
b0414a81ec | ||
|
|
f078e325f8 | ||
|
|
1a5dc202cb | ||
|
|
79d7f023e8 | ||
|
|
913912d96b | ||
|
|
cfa22b345a | ||
|
|
ce4948ec23 | ||
|
|
86d10ca87b | ||
|
|
057b5ec646 | ||
|
|
16eede4288 | ||
|
|
e3c0c638c1 | ||
|
|
e57ee8db35 | ||
|
|
cd2a80cf77 | ||
|
|
300ecd9b16 | ||
|
|
a95fbfe6c6 | ||
|
|
79168b637e | ||
|
|
42c7ddebaa | ||
|
|
ac349573f2 | ||
|
|
28feb013d7 | ||
|
|
4129cddbeb | ||
|
|
4463e6da5d | ||
|
|
db0b97cca1 | ||
|
|
be58b03137 | ||
|
|
2df0056ef1 | ||
|
|
eeca521baf | ||
|
|
93b9792096 | ||
|
|
a8079e5854 | ||
|
|
61f9ba5b46 | ||
|
|
a99ff467f6 | ||
|
|
5abad87ce3 | ||
|
|
931a4ddc88 | ||
|
|
784a615864 | ||
|
|
697c81b74e | ||
|
|
d6b8cb9e7c | ||
|
|
f4a031a412 | ||
|
|
1749abbd97 | ||
|
|
1497afcfda | ||
|
|
ec9fdace22 | ||
|
|
c1a039ef91 | ||
|
|
69d1edc036 | ||
|
|
b69b001755 | ||
|
|
97414f1886 | ||
|
|
478daf8614 | ||
|
|
c74b130303 | ||
|
|
11bc318d8d | ||
|
|
77125bee7e | ||
|
|
4354cd7c38 | ||
|
|
0e31ccaea0 | ||
|
|
e32fa8bd35 | ||
|
|
d209652caa | ||
|
|
fe672ba5e0 | ||
|
|
04d41085a3 | ||
|
|
36604f752e | ||
|
|
9056f446bb | ||
|
|
a33d1b979d | ||
|
|
4670f82e65 | ||
|
|
a07346b364 | ||
|
|
2b1e930cdf | ||
|
|
6b75ea3b01 | ||
|
|
ff8bc93080 | ||
|
|
85eff566dd | ||
|
|
9480691de5 | ||
|
|
463f3dbb35 | ||
|
|
72f304baa5 | ||
|
|
b4d01659e0 | ||
|
|
a733d72089 | ||
|
|
8d0a75cb5f | ||
|
|
8a56702341 | ||
|
|
af638cf5ce | ||
|
|
839248c74c | ||
|
|
4b488de10e | ||
|
|
112d6ead91 | ||
|
|
a5e2e78dee | ||
|
|
b91f79b0bb | ||
|
|
bc75b0259b | ||
|
|
51663c5439 | ||
|
|
86a5e2538c | ||
|
|
ee94b0ce17 | ||
|
|
03b3139705 | ||
|
|
44b8480832 | ||
|
|
e6960320dd | ||
|
|
775353fe82 | ||
|
|
e013aff9fe | ||
|
|
a22bf4c296 | ||
|
|
12a00de2ec | ||
|
|
6cbaf7e0ae | ||
|
|
9b8d814082 | ||
|
|
20a0ed81c5 | ||
|
|
a040d1d2a6 | ||
|
|
7c68b889eb | ||
|
|
5ea3c11883 | ||
|
|
f355d907e6 | ||
|
|
da5b99c840 | ||
|
|
4f9c4f7b40 | ||
|
|
143c47c887 | ||
|
|
2ca5aeda96 | ||
|
|
9b1ee633e6 | ||
|
|
c06998184e | ||
|
|
84aa2b94e5 | ||
|
|
64f7c7bf36 | ||
|
|
18d9263d60 | ||
|
|
e0e8165992 | ||
|
|
e826f8a020 | ||
|
|
9486c50e67 | ||
|
|
66424c3c93 | ||
|
|
cff871e8a6 | ||
|
|
51cd435ad8 | ||
|
|
bc8bf989f3 | ||
|
|
87150b7c6b | ||
|
|
5797797904 | ||
|
|
ab7b9c4523 | ||
|
|
68bf4459c3 | ||
|
|
05052fd21c | ||
|
|
f4e3f87c2e | ||
|
|
f47e1afd57 | ||
|
|
aa4be01bdf | ||
|
|
f8d4d5e546 | ||
|
|
f628b8ad9d | ||
|
|
bb4eb70a4a | ||
|
|
7b1a533bf2 | ||
|
|
d32526f04a | ||
|
|
76bc2cf5db | ||
|
|
90b3ebb27d | ||
|
|
4b65d9145e | ||
|
|
9badbb8dff | ||
|
|
70bcca23a4 | ||
|
|
5ba15d0db5 | ||
|
|
fef70fdac0 | ||
|
|
62eb00aacf | ||
|
|
3f41ea574b | ||
|
|
e865c1a2b5 | ||
|
|
f4a798c64f | ||
|
|
7806604666 | ||
|
|
917f8087b4 | ||
|
|
b7a9a6153a | ||
|
|
fdc5f89060 | ||
|
|
7fc0959c83 | ||
|
|
95ff64d1d6 | ||
|
|
67f94007d0 | ||
|
|
930f27d009 | ||
|
|
6125e0fd13 | ||
|
|
c8e8cf9dfb | ||
|
|
ac8122a154 | ||
|
|
aa7c909f14 | ||
|
|
40d294b29c | ||
|
|
ff1fd5b5b5 | ||
|
|
6aa87d7df2 | ||
|
|
7732eeb815 | ||
|
|
9eb37feb6c | ||
|
|
e3b38339a0 | ||
|
|
f92695e677 | ||
|
|
9473d54fc7 | ||
|
|
33783eae4f | ||
|
|
f6c3570c26 | ||
|
|
96a76a763e | ||
|
|
209bc925c0 | ||
|
|
a77fb427d0 | ||
|
|
6fc358b9ea | ||
|
|
742ac3088d | ||
|
|
8b958e5280 | ||
|
|
71c727e586 | ||
|
|
c590e8bd6f | ||
|
|
71a13c3e91 | ||
|
|
d7e7f049e7 | ||
|
|
b23332e167 | ||
|
|
21c637c365 | ||
|
|
2b7aecb346 | ||
|
|
5c669f7925 | ||
|
|
748f445de5 | ||
|
|
2127ee9a09 | ||
|
|
7ccb1c5938 | ||
|
|
26d1a2ebfe | ||
|
|
e0bd683d88 | ||
|
|
3c5332b236 | ||
|
|
d4e847bd9a | ||
|
|
cee16536cd | ||
|
|
811e75dc67 | ||
|
|
a4a77d70f7 | ||
|
|
13e2de3cb4 | ||
|
|
7925bbf454 | ||
|
|
39eb0a01d2 | ||
|
|
305e50004a | ||
|
|
9974784596 | ||
|
|
105c5a7fd4 | ||
|
|
7baad9c72e | ||
|
|
a5e8705ea3 | ||
|
|
7b2a5c3a30 | ||
|
|
3dcb33026a | ||
|
|
e956ea5482 | ||
|
|
64f50ab278 | ||
|
|
b152937f30 | ||
|
|
4e6a9d990c | ||
|
|
d6fb220b2c | ||
|
|
13fc539f8b | ||
|
|
d2b5d6342c | ||
|
|
eb644d4e6a | ||
|
|
d8a2efc691 | ||
|
|
163687aef3 | ||
|
|
b41f1a897d | ||
|
|
006a5723da | ||
|
|
daede9c9cf | ||
|
|
af7d14cd59 | ||
|
|
ba14fe3600 | ||
|
|
af726d1f15 | ||
|
|
3a2efb351d | ||
|
|
c9dc61c311 | ||
|
|
500f151d96 | ||
|
|
9708fc5d6c | ||
|
|
c697501285 | ||
|
|
e213cfc2ba | ||
|
|
270c1304a8 | ||
|
|
b887cf4612 | ||
|
|
e0573b721e | ||
|
|
22712c5dac | ||
|
|
7339b3d8cc | ||
|
|
0eda34b41f | ||
|
|
9d9e845198 | ||
|
|
6b7ead4461 | ||
|
|
3ce3056c4a | ||
|
|
aa73cbf459 | ||
|
|
fa3f109eb9 | ||
|
|
c22afa5725 | ||
|
|
318086571d | ||
|
|
260ef8edd5 | ||
|
|
a3a933a797 | ||
|
|
d2d298604c | ||
|
|
318f2ee003 | ||
|
|
657d59268c | ||
|
|
54b7931779 | ||
|
|
aadba55796 | ||
|
|
1a8d65d7a9 | ||
|
|
36f7d0957a | ||
|
|
85a33ff6aa | ||
|
|
ebef4feddb | ||
|
|
0a34bebd9c | ||
|
|
d6acd96dec | ||
|
|
3ac88acf11 | ||
|
|
2a7cffed2a | ||
|
|
08c2089b3d | ||
|
|
87521b07ce | ||
|
|
5de7efc1dc | ||
|
|
d4fb80772e | ||
|
|
97886bf62e | ||
|
|
83657e0a68 | ||
|
|
adf293f6c9 | ||
|
|
d35120feb9 | ||
|
|
00020f49fe | ||
|
|
206d1b231a | ||
|
|
008d8f491c | ||
|
|
ae5543b6fa | ||
|
|
2c2e6c5c25 | ||
|
|
718bed5758 | ||
|
|
d2c524e7fd | ||
|
|
c172221038 | ||
|
|
b006edf9c4 | ||
|
|
9f2dc4d6d1 | ||
|
|
8816d6c0c5 | ||
|
|
6dddad01fc | ||
|
|
7b7f0a7380 | ||
|
|
8f780e79e7 | ||
|
|
5a2b00fb95 | ||
|
|
25158d4d92 | ||
|
|
3c9e9fa746 | ||
|
|
e5a7e932cf | ||
|
|
84d15b0c7f | ||
|
|
79f216a491 | ||
|
|
b86645c1a7 | ||
|
|
3f359f1059 | ||
|
|
7c7117c39a | ||
|
|
92b8c68b94 | ||
|
|
7b7d722bae | ||
|
|
227c3197b4 | ||
|
|
cab1ba8970 | ||
|
|
7395c94d00 | ||
|
|
2a2dadb952 | ||
|
|
b60dcf8036 | ||
|
|
edeb706d19 | ||
|
|
cd76a2d217 | ||
|
|
2aaa6c3986 | ||
|
|
dfdd56dbec | ||
|
|
f521704ade | ||
|
|
0e0cf9cd3e | ||
|
|
60484bd45e | ||
|
|
5978ba32c0 | ||
|
|
45efa8f40d | ||
|
|
3392e1f0bf | ||
|
|
c734385f18 | ||
|
|
148434d833 | ||
|
|
d5e4a965cc | ||
|
|
c8ac4d9e1b | ||
|
|
1e17461601 | ||
|
|
4432244bc3 | ||
|
|
0d087f84a5 | ||
|
|
286deb277b | ||
|
|
3f2c1139ea | ||
|
|
1f3163942a | ||
|
|
28affcc60a | ||
|
|
f59c2cf7f5 | ||
|
|
86ff52ec36 | ||
|
|
17e7364bdb | ||
|
|
71e4b60a30 | ||
|
|
a3e1ff637d | ||
|
|
a3d2ba1444 | ||
|
|
9baa594a56 | ||
|
|
49de11c3ae | ||
|
|
690fbdc73d | ||
|
|
4d52824895 | ||
|
|
2e13e75fc6 | ||
|
|
f6f6462590 | ||
|
|
79fee16629 | ||
|
|
073f63251a | ||
|
|
20125dc04b | ||
|
|
9f1f8d62f0 | ||
|
|
99d432785c | ||
|
|
54e5401a96 | ||
|
|
6b5d7406d6 | ||
|
|
c6bfeba61a | ||
|
|
af1c8cc7e0 | ||
|
|
879161ed4c | ||
|
|
a15944774c | ||
|
|
42612a4f92 | ||
|
|
5c39935e88 | ||
|
|
35c941c540 | ||
|
|
92931b0d4d | ||
|
|
7c3d2f5578 | ||
|
|
9c809ba147 | ||
|
|
25fb1bb837 | ||
|
|
2d01086a3e | ||
|
|
f8af1e9014 | ||
|
|
2b34a5c646 | ||
|
|
22ca3db870 | ||
|
|
247ca97fbd | ||
|
|
b346b25a7b | ||
|
|
496cf3da4f | ||
|
|
f1b0130389 | ||
|
|
39db3be151 | ||
|
|
517ad7e77c | ||
|
|
5fa10a3f8e | ||
|
|
aa3986e9f2 | ||
|
|
9105c02681 | ||
|
|
6632727d00 | ||
|
|
34ccd5aa86 | ||
|
|
23979bdbee | ||
|
|
dda53292bf | ||
|
|
c98c5f13f7 | ||
|
|
2a69967863 | ||
|
|
781ef806de | ||
|
|
8794c51e42 | ||
|
|
9e89ddf2f1 | ||
|
|
f480a89e7a | ||
|
|
df44fb9827 | ||
|
|
fca9cacc4e | ||
|
|
fa7783972b | ||
|
|
e50b4280f2 | ||
|
|
17c9ccfdb0 | ||
|
|
d35af4c048 | ||
|
|
5fef9bbceb | ||
|
|
507acaa7c9 | ||
|
|
28eb9b62a8 | ||
|
|
061eeb809f | ||
|
|
0db5c6ac8e | ||
|
|
2fd6fd4624 | ||
|
|
db67ae2de4 | ||
|
|
fedcebbe4d | ||
|
|
eaaeb356d7 | ||
|
|
0cddbc2420 | ||
|
|
1b44520ff7 | ||
|
|
ee4dc86c15 | ||
|
|
cbc61daa56 | ||
|
|
f1cd6a06ec | ||
|
|
70248b9684 | ||
|
|
0a5cf8ae42 | ||
|
|
4dc5b18671 | ||
|
|
daffb950c3 | ||
|
|
1a6ebb9606 | ||
|
|
e397efb622 | ||
|
|
748cb0bb8d | ||
|
|
233a449e53 | ||
|
|
09ad8b6238 | ||
|
|
778f703d3d | ||
|
|
8234e4e168 | ||
|
|
08e588a6c4 | ||
|
|
43f550e48b | ||
|
|
5663029ba6 | ||
|
|
0869e23dc1 | ||
|
|
fc4779e095 | ||
|
|
1667603d1c | ||
|
|
81b210cf14 | ||
|
|
ca5fde8ef5 | ||
|
|
fef88734d0 | ||
|
|
aa1727d16f | ||
|
|
7c1afb6493 | ||
|
|
64e7757872 | ||
|
|
3b08250331 | ||
|
|
cd02638db6 | ||
|
|
5109811182 | ||
|
|
33ba8cabd1 | ||
|
|
dcac6028a2 | ||
|
|
25ab435129 | ||
|
|
cea5dc6216 | ||
|
|
bc525d29e1 | ||
|
|
cdfe0ca150 | ||
|
|
bd679e018d | ||
|
|
a6ee18448a | ||
|
|
7d96b3e89e | ||
|
|
e6723f194a | ||
|
|
e55541ea87 | ||
|
|
2676ff8ee3 | ||
|
|
f024fe4488 | ||
|
|
0d3dfb8d0f | ||
|
|
65e1951f5d | ||
|
|
fd6eb91f79 | ||
|
|
2aea0f2ac5 | ||
|
|
e1b5aa7011 | ||
|
|
cdc4d29745 | ||
|
|
5d00792e1f | ||
|
|
1c334f3231 | ||
|
|
9ba3182d19 | ||
|
|
285ad448d7 | ||
|
|
e81fedba43 | ||
|
|
6a47f973b7 | ||
|
|
7aa918cd46 | ||
|
|
183c9dd736 | ||
|
|
5621075cb7 | ||
|
|
b13d2087c2 | ||
|
|
89740af2ab | ||
|
|
4329dfd128 | ||
|
|
bb18a82a9c | ||
|
|
d073fe467d | ||
|
|
e37e885546 | ||
|
|
471ded85f7 | ||
|
|
e084655e69 | ||
|
|
330acb55f4 | ||
|
|
e1ace99e05 | ||
|
|
c0bfa07ea7 | ||
|
|
f4fceac372 | ||
|
|
c1f5345987 | ||
|
|
05992130d0 | ||
|
|
9c646712e0 | ||
|
|
0edff49957 | ||
|
|
0f709cb06a | ||
|
|
656dbbb9f1 | ||
|
|
9497a75c95 | ||
|
|
30c4ed87b5 | ||
|
|
e839765ddc | ||
|
|
055737a6e8 | ||
|
|
fbc609230a | ||
|
|
46f86a54c1 | ||
|
|
70f5231020 | ||
|
|
0ec0feed2c | ||
|
|
60cd505ee1 | ||
|
|
205a719649 | ||
|
|
af7d222a1e | ||
|
|
70f430f635 | ||
|
|
2a92a223f6 | ||
|
|
1ed43614c2 | ||
|
|
3a5295574f | ||
|
|
475b1cb1b8 | ||
|
|
55260a886d | ||
|
|
f77e03ceb0 | ||
|
|
4f7bf5ad58 | ||
|
|
460ea1aa07 | ||
|
|
7aa5f5cb80 | ||
|
|
9cc4133184 | ||
|
|
0d3721324d | ||
|
|
510249d282 | ||
|
|
77d840593b | ||
|
|
8c9a5c4ab5 | ||
|
|
ce497fff27 | ||
|
|
c3cada6bd5 | ||
|
|
bbacfe403c | ||
|
|
159031a071 | ||
|
|
4067927a23 | ||
|
|
d090343083 | ||
|
|
ed130366db | ||
|
|
d4ae40fec4 | ||
|
|
17c864dba3 | ||
|
|
398d6d1efd | ||
|
|
30b46b0f9b | ||
|
|
2d123fa11c | ||
|
|
352a651ac0 | ||
|
|
26b971978d | ||
|
|
478324ea62 | ||
|
|
fa447b2813 | ||
|
|
322790bfdb | ||
|
|
9ea7c18d66 | ||
|
|
2a0d16d57d | ||
|
|
033e2b27a9 | ||
|
|
c7af454bf5 | ||
|
|
d2be9df7a7 | ||
|
|
ae81f4e20a | ||
|
|
261c24b704 | ||
|
|
cc1995170e | ||
|
|
86c785fded | ||
|
|
31c2b1af19 | ||
|
|
90b973f529 | ||
|
|
1164763e1a | ||
|
|
b4319630b1 | ||
|
|
0586c6bdf2 | ||
|
|
c2f60f33e6 | ||
|
|
b118bc959d | ||
|
|
6ef5c2e0c3 | ||
|
|
728fd5b758 | ||
|
|
5bae455e4e | ||
|
|
9387903491 | ||
|
|
12b67bd556 | ||
|
|
9a30bcbd94 | ||
|
|
ab0f096bf2 | ||
|
|
2d8a29d2fd | ||
|
|
8d7de1543b | ||
|
|
b1b41a9b0c | ||
|
|
e9033040f6 | ||
|
|
4388f00607 | ||
|
|
9b8db06349 | ||
|
|
a4f55f6e5d | ||
|
|
d0cde66e92 | ||
|
|
f636c0eb88 | ||
|
|
5760d3180e | ||
|
|
8c4f98131b | ||
|
|
959a433e61 | ||
|
|
f72845a1b4 | ||
|
|
254f4ba574 | ||
|
|
d6f3b1b85f | ||
|
|
47b5b7c4b4 | ||
|
|
bcfdae62e3 | ||
|
|
c60f1c0031 | ||
|
|
d76802f563 | ||
|
|
7b39b31f6c | ||
|
|
a2585a8bb1 | ||
|
|
fe0c4767c7 | ||
|
|
e0b8b82a15 | ||
|
|
f90fa85e77 | ||
|
|
6d5b4e4471 | ||
|
|
1172a33117 | ||
|
|
8530f8ddcc | ||
|
|
e66e4fefed | ||
|
|
1237e839ca | ||
|
|
8ec08063f4 | ||
|
|
0d0004018b | ||
|
|
9124604dc4 | ||
|
|
4f05a7b8d0 | ||
|
|
db766bd9ae | ||
|
|
557603d2ae | ||
|
|
b46ca55c7d | ||
|
|
4788172206 | ||
|
|
6a118f172e | ||
|
|
d5abdfa3b0 | ||
|
|
1d24cb94b4 | ||
|
|
86e7f24238 | ||
|
|
0356f970f3 | ||
|
|
d340038f48 | ||
|
|
1d5e8e07c6 | ||
|
|
24f17479de | ||
|
|
84b8096cc9 | ||
|
|
5c38241e33 | ||
|
|
b590c73c08 | ||
|
|
68fcf9d2df |
@@ -19,8 +19,7 @@ from invokeai.app.invocations.model import CLIPField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@@ -83,10 +82,9 @@ class CompelInvocation(BaseInvocation):
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LoraPatcher.apply_lora_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_te_",
|
||||
ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
@@ -179,9 +177,9 @@ class SDXLPromptInvocationBase:
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LoraPatcher.apply_lora_patches(
|
||||
ModelPatcher.apply_lora(
|
||||
text_encoder,
|
||||
patches=_lora_loader(),
|
||||
loras=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
|
||||
@@ -36,8 +36,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
@@ -980,10 +979,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
LoraPatcher.apply_lora_patches(
|
||||
model=unet,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_unet_",
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
loras=_lora_loader(),
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
@@ -29,8 +29,6 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -189,16 +187,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with (
|
||||
transformer_info.model_on_device() as (cached_weights, transformer),
|
||||
# Apply the LoRA after transformer has been moved to its target device for faster patching.
|
||||
LoraPatcher.apply_lora_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix="",
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
@@ -252,13 +241,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""FLUX LoRA Loader Output"""
|
||||
|
||||
transformer: TransformerField = OutputField(
|
||||
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="FLUX LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="FLUX Transformer",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if any(lora.lora.key == lora_key for lora in self.transformer.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
|
||||
transformer = self.transformer.model_copy(deep=True)
|
||||
transformer.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return FluxLoRALoaderOutput(transformer=transformer)
|
||||
@@ -69,7 +69,6 @@ class CLIPField(BaseModel):
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
@@ -203,7 +202,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
transformer=TransformerField(transformer=transformer),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
|
||||
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
@@ -204,11 +204,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
LoraPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
):
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
|
||||
672
invokeai/backend/lora.py
Normal file
672
invokeai/backend/lora.py
Normal file
@@ -0,0 +1,672 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_a" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
# norms
|
||||
elif "w_norm" in values:
|
||||
layer = NormLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
|
||||
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
|
||||
|
||||
# Next, check that this is likely a FLUX model by spot-checking a few keys.
|
||||
expected_keys = [
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
|
||||
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
|
||||
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
|
||||
]
|
||||
all_expected_keys_present = all(k in state_dict for k in expected_keys)
|
||||
|
||||
return all_keys_in_peft_format and all_expected_keys_present
|
||||
|
||||
|
||||
# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
|
||||
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
|
||||
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
|
||||
|
||||
This function is based on:
|
||||
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
|
||||
"""
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
|
||||
|
||||
# Remove the "transformer." prefix from all keys.
|
||||
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
|
||||
|
||||
# Constants for FLUX.1
|
||||
num_double_layers = 19
|
||||
num_single_layers = 38
|
||||
# inner_dim = 3072
|
||||
# mlp_ratio = 4.0
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
|
||||
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
|
||||
if src_key in grouped_state_dict:
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
layers[dst_key] = LoRALayer(
|
||||
values={
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
"alpha": torch.tensor(alpha),
|
||||
},
|
||||
)
|
||||
assert len(src_layer_dict) == 0
|
||||
|
||||
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
|
||||
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
|
||||
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
|
||||
"""
|
||||
# We expect that either all src keys are present or none of them are. Verify this.
|
||||
keys_present = [key in grouped_state_dict for key in src_keys]
|
||||
assert all(keys_present) or not any(keys_present)
|
||||
|
||||
# If none of the keys are present, return early.
|
||||
if not any(keys_present):
|
||||
return
|
||||
|
||||
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
|
||||
sub_layers: list[LoRALayerBase] = []
|
||||
for src_layer_dict in src_layer_dicts:
|
||||
sub_layers.append(
|
||||
LoRALayer(
|
||||
values={
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
"alpha": torch.tensor(alpha),
|
||||
},
|
||||
)
|
||||
)
|
||||
assert len(src_layer_dict) == 0
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
|
||||
|
||||
# time_text_embed.timestep_embedder -> time_in.
|
||||
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
|
||||
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")
|
||||
|
||||
# time_text_embed.text_embedder -> vector_in.
|
||||
add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
|
||||
add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")
|
||||
|
||||
# time_text_embed.guidance_embedder -> guidance_in.
|
||||
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
|
||||
add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")
|
||||
|
||||
# context_embedder -> txt_in.
|
||||
add_lora_layer_if_present("context_embedder", "txt_in")
|
||||
|
||||
# x_embedder -> img_in.
|
||||
add_lora_layer_if_present("x_embedder", "img_in")
|
||||
|
||||
# Double transformer blocks.
|
||||
for i in range(num_double_layers):
|
||||
# norms.
|
||||
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
|
||||
add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")
|
||||
|
||||
# Q, K, V
|
||||
add_qkv_lora_layer_if_present(
|
||||
[
|
||||
f"transformer_blocks.{i}.attn.to_q",
|
||||
f"transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer_blocks.{i}.attn.to_v",
|
||||
],
|
||||
f"double_blocks.{i}.img_attn.qkv",
|
||||
)
|
||||
add_qkv_lora_layer_if_present(
|
||||
[
|
||||
f"transformer_blocks.{i}.attn.add_q_proj",
|
||||
f"transformer_blocks.{i}.attn.add_k_proj",
|
||||
f"transformer_blocks.{i}.attn.add_v_proj",
|
||||
],
|
||||
f"double_blocks.{i}.txt_attn.qkv",
|
||||
)
|
||||
|
||||
# ff img_mlp
|
||||
add_lora_layer_if_present(
|
||||
f"transformer_blocks.{i}.ff.net.0.proj",
|
||||
f"double_blocks.{i}.img_mlp.0",
|
||||
)
|
||||
add_lora_layer_if_present(
|
||||
f"transformer_blocks.{i}.ff.net.2",
|
||||
f"double_blocks.{i}.img_mlp.2",
|
||||
)
|
||||
|
||||
# ff txt_mlp
|
||||
add_lora_layer_if_present(
|
||||
f"transformer_blocks.{i}.ff_context.net.0.proj",
|
||||
f"double_blocks.{i}.txt_mlp.0",
|
||||
)
|
||||
add_lora_layer_if_present(
|
||||
f"transformer_blocks.{i}.ff_context.net.2",
|
||||
f"double_blocks.{i}.txt_mlp.2",
|
||||
)
|
||||
|
||||
# output projections.
|
||||
add_lora_layer_if_present(
|
||||
f"transformer_blocks.{i}.attn.to_out.0",
|
||||
f"double_blocks.{i}.img_attn.proj",
|
||||
)
|
||||
add_lora_layer_if_present(
|
||||
f"transformer_blocks.{i}.attn.to_add_out",
|
||||
f"double_blocks.{i}.txt_attn.proj",
|
||||
)
|
||||
|
||||
# Single transformer blocks.
|
||||
for i in range(num_single_layers):
|
||||
# norms
|
||||
add_lora_layer_if_present(
|
||||
f"single_transformer_blocks.{i}.norm.linear",
|
||||
f"single_blocks.{i}.modulation.lin",
|
||||
)
|
||||
|
||||
# Q, K, V, mlp
|
||||
add_qkv_lora_layer_if_present(
|
||||
[
|
||||
f"single_transformer_blocks.{i}.attn.to_q",
|
||||
f"single_transformer_blocks.{i}.attn.to_k",
|
||||
f"single_transformer_blocks.{i}.attn.to_v",
|
||||
f"single_transformer_blocks.{i}.proj_mlp",
|
||||
],
|
||||
f"single_blocks.{i}.linear1",
|
||||
)
|
||||
|
||||
# Output projections.
|
||||
add_lora_layer_if_present(
|
||||
f"single_transformer_blocks.{i}.proj_out",
|
||||
f"single_blocks.{i}.linear2",
|
||||
)
|
||||
|
||||
# Final layer.
|
||||
add_lora_layer_if_present("proj_out", "final_layer.linear")
|
||||
|
||||
# Assert that all keys were processed.
|
||||
assert len(grouped_state_dict) == 0
|
||||
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
||||
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Groups the keys in the state dict by layer."""
|
||||
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key in state_dict:
|
||||
# Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
|
||||
parts = key.rsplit(".", maxsplit=2)
|
||||
layer_name = parts[0]
|
||||
key_name = ".".join(parts[1:])
|
||||
if layer_name not in layer_dict:
|
||||
layer_dict[layer_name] = {}
|
||||
layer_dict[layer_name][key_name] = state_dict[key]
|
||||
return layer_dict
|
||||
@@ -1,80 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Dict, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
|
||||
# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
|
||||
# Example keys:
|
||||
# lora_unet_double_blocks_0_img_attn_proj.alpha
|
||||
# lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
|
||||
# lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
|
||||
FLUX_KOHYA_KEY_REGEX = (
|
||||
r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
|
||||
)
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())
|
||||
|
||||
|
||||
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
layer_name, param_name = key.split(".", 1)
|
||||
if layer_name not in grouped_state_dict:
|
||||
grouped_state_dict[layer_name] = {}
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
# Convert the state dict to the InvokeAI format.
|
||||
grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
|
||||
|
||||
# Create and return the LoRAModelRaw.
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
|
||||
"""Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
|
||||
|
||||
Example key conversions:
|
||||
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
|
||||
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
|
||||
"lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
|
||||
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
|
||||
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
|
||||
"lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv"
|
||||
"""
|
||||
|
||||
def replace_func(match: re.Match[str]) -> str:
|
||||
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
|
||||
if match.group(4):
|
||||
s += f".{match.group(4)}"
|
||||
return s
|
||||
|
||||
converted_dict: dict[str, T] = {}
|
||||
for k, v in state_dict.items():
|
||||
match = re.match(FLUX_KOHYA_KEY_REGEX, k)
|
||||
if match:
|
||||
new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
|
||||
converted_dict[new_key] = v
|
||||
else:
|
||||
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
return converted_dict
|
||||
@@ -1,29 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
|
||||
|
||||
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, values in grouped_state_dict.items():
|
||||
layers[layer_key] = any_lora_layer_from_state_dict(values)
|
||||
|
||||
return LoRAModelRaw(layers=layers)
|
||||
|
||||
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
@@ -1,154 +0,0 @@
|
||||
import bisect
|
||||
from typing import Dict, List, Tuple, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.full_layer import FullLayer
|
||||
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
|
||||
@@ -1,46 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class ConcatenatedLoRALayer(LoRALayerBase):
|
||||
"""A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.
|
||||
|
||||
This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
|
||||
Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
|
||||
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
|
||||
# Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
|
||||
super().__init__(values={})
|
||||
|
||||
self._lora_layers = lora_layers
|
||||
self._concat_axis = concat_axis
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
|
||||
# require this value, we will need to implement chunking of the original weight tensor here.
|
||||
layer_weights = [lora_layer.get_weight(None) for lora_layer in self._lora_layers] # pyright: ignore[reportArgumentType]
|
||||
return torch.cat(layer_weights, dim=self._concat_axis)
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
# TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
|
||||
# require this value, we will need to implement chunking of the original bias tensor here.
|
||||
layer_biases = [lora_layer.get_bias(None) for lora_layer in self._lora_layers] # pyright: ignore[reportArgumentType]
|
||||
layer_bias_is_none = [layer_bias is None for layer_bias in layer_biases]
|
||||
if any(layer_bias_is_none):
|
||||
assert all(layer_bias_is_none)
|
||||
return None
|
||||
|
||||
# Ignore the type error, because we have just verified that all layer biases are non-None.
|
||||
return torch.cat(layer_biases, dim=self._concat_axis)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
return sum(lora_layer.calc_size() for lora_layer in self._lora_layers)
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
for lora_layer in self._lora_layers:
|
||||
lora_layer.to(device=device, dtype=dtype)
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
@@ -1,41 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
@@ -1,68 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
@@ -1,113 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
@@ -1,58 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
@@ -1,33 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.full_layer import FullLayer
|
||||
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
|
||||
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
if "lora_up.weight" in state_dict:
|
||||
# LoRA a.k.a LoCon
|
||||
return LoRALayer(state_dict)
|
||||
elif "hada_w1_a" in state_dict:
|
||||
return LoHALayer(state_dict)
|
||||
elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
|
||||
return LoKRLayer(state_dict)
|
||||
elif "diff" in state_dict:
|
||||
# Full a.k.a Diff
|
||||
return FullLayer(state_dict)
|
||||
elif "on_input" in state_dict:
|
||||
return IA3Layer(state_dict)
|
||||
elif "w_norm" in state_dict:
|
||||
return NormLayer(state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
|
||||
@@ -1,22 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
def __init__(self, layers: Dict[str, AnyLoRALayer]):
|
||||
self.layers = layers
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
@@ -1,148 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
class LoraPatcher:
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Apply one or more LoRA patches to a model within a context manager.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
|
||||
that the LoRA patches do not need to be loaded into memory all at once.
|
||||
:param prefix: The keys in the patches will be filtered to only include weights with this prefix.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
|
||||
"""
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoraPatcher.apply_lora_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del patch
|
||||
|
||||
yield
|
||||
finally:
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def apply_lora_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply a single LoRA patch to a model.
|
||||
:param model: The model to patch.
|
||||
:param patch: LoRA model to patch in.
|
||||
:param patch_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||
"""
|
||||
|
||||
if patch_weight == 0:
|
||||
return
|
||||
|
||||
# If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
|
||||
# submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
|
||||
# replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
|
||||
# without searching, but some legacy code still uses flattened keys.
|
||||
layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
|
||||
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for layer_key, layer in patch.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LoraPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= patch_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
def _get_submodule(
|
||||
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
|
||||
) -> tuple[str, torch.nn.Module]:
|
||||
"""Get the submodule corresponding to the given layer key.
|
||||
:param model: The model to search.
|
||||
:param layer_key: The layer key to search for.
|
||||
:param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
|
||||
with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
|
||||
searching, but some legacy code still uses flattened keys.
|
||||
:return: A tuple containing the module key and the submodule.
|
||||
"""
|
||||
if not layer_key_is_flattened:
|
||||
return layer_key, model.get_submodule(layer_key)
|
||||
|
||||
# Handle flattened keys.
|
||||
assert "." not in layer_key
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = layer_key.split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return module_key, module
|
||||
@@ -5,18 +5,8 @@ from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -55,33 +45,14 @@ class LoRALoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
|
||||
# Load the state dict from the model file.
|
||||
if model_path.suffix == ".safetensors":
|
||||
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
|
||||
|
||||
model.to(dtype=self._torch_dtype)
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
@@ -15,7 +15,7 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
@@ -10,10 +10,6 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
@@ -248,9 +244,7 @@ class ModelProbe(object):
|
||||
return ModelType.VAE
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
return ModelType.LoRA
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
|
||||
return ModelType.LoRA
|
||||
elif key.startswith(("controlnet", "control_model", "input_blocks")):
|
||||
return ModelType.ControlNet
|
||||
@@ -560,21 +554,12 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for LoRA checkpoints."""
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
|
||||
# TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
|
||||
# ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
|
||||
# file, but the weight keys are in the diffusers format.
|
||||
return ModelFormat.Diffusers
|
||||
return ModelFormat.LyCORIS
|
||||
return ModelFormat("lycoris")
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
|
||||
self.checkpoint
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
# If we've gotten here, we assume that the model is a Stable Diffusion model.
|
||||
token_vector_length = lora_token_vector_length(self.checkpoint)
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
|
||||
@@ -5,18 +5,32 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
||||
(lora_model2, 0.4),
|
||||
]
|
||||
with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
# unet with applied loras
|
||||
# unmodified unet
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
@@ -40,6 +54,95 @@ class ModelPatcher:
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(
|
||||
unet,
|
||||
loras=loras,
|
||||
prefix="lora_unet_",
|
||||
cached_weights=cached_weights,
|
||||
):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for lora_model, lora_weight in loras:
|
||||
LoRAExt.patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
lora=lora_model,
|
||||
lora_weight=lora_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
@@ -179,6 +282,26 @@ class ModelPatcher:
|
||||
|
||||
|
||||
class ONNXModelPatcher:
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: OnnxRuntimeModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
# based on
|
||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||
@classmethod
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
@@ -30,14 +31,107 @@ class LoRAExt(ExtensionBase):
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
LoraPatcher.apply_lora_patch(
|
||||
self.patch_model(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
patch=lora_model,
|
||||
patch_weight=self._weight,
|
||||
lora=lora_model,
|
||||
lora_weight=self._weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def patch_model(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
lora: LoRAModelRaw,
|
||||
lora_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param lora: LoRA model to patch in.
|
||||
:param lora_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||
"""
|
||||
|
||||
if lora_weight == 0:
|
||||
return
|
||||
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= lora_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@@ -16,14 +16,6 @@ module.exports = {
|
||||
'no-promise-executor-return': 'error',
|
||||
// https://eslint.org/docs/latest/rules/require-await
|
||||
'require-await': 'error',
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
object: 'crypto',
|
||||
property: 'randomUUID',
|
||||
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
|
||||
},
|
||||
],
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
|
||||
@@ -58,7 +58,7 @@
|
||||
"@dnd-kit/sortable": "^8.0.0",
|
||||
"@dnd-kit/utilities": "^3.2.2",
|
||||
"@fontsource-variable/inter": "^5.0.20",
|
||||
"@invoke-ai/ui-library": "^0.0.33",
|
||||
"@invoke-ai/ui-library": "^0.0.32",
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@reduxjs/toolkit": "2.2.3",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
@@ -92,7 +92,7 @@
|
||||
"react-i18next": "^14.1.3",
|
||||
"react-icons": "^5.2.1",
|
||||
"react-redux": "9.1.2",
|
||||
"react-resizable-panels": "^2.1.2",
|
||||
"react-resizable-panels": "^2.0.23",
|
||||
"react-use": "^17.5.1",
|
||||
"react-virtuoso": "^4.9.0",
|
||||
"reactflow": "^11.11.4",
|
||||
|
||||
62
invokeai/frontend/web/pnpm-lock.yaml
generated
62
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -24,8 +24,8 @@ dependencies:
|
||||
specifier: ^5.0.20
|
||||
version: 5.0.20
|
||||
'@invoke-ai/ui-library':
|
||||
specifier: ^0.0.33
|
||||
version: 0.0.33(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^0.0.32
|
||||
version: 0.0.32(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@nanostores/react':
|
||||
specifier: ^0.7.3
|
||||
version: 0.7.3(nanostores@0.11.2)(react@18.3.1)
|
||||
@@ -126,8 +126,8 @@ dependencies:
|
||||
specifier: 9.1.2
|
||||
version: 9.1.2(@types/react@18.3.3)(react@18.3.1)(redux@5.0.1)
|
||||
react-resizable-panels:
|
||||
specifier: ^2.1.2
|
||||
version: 2.1.2(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^2.0.23
|
||||
version: 2.0.23(react-dom@18.3.1)(react@18.3.1)
|
||||
react-use:
|
||||
specifier: ^17.5.1
|
||||
version: 17.5.1(react-dom@18.3.1)(react@18.3.1)
|
||||
@@ -2056,7 +2056,7 @@ packages:
|
||||
dependencies:
|
||||
'@chakra-ui/dom-utils': 2.1.0
|
||||
react: 18.3.1
|
||||
react-focus-lock: 2.13.2(@types/react@18.3.3)(react@18.3.1)
|
||||
react-focus-lock: 2.13.0(@types/react@18.3.3)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
@@ -2253,7 +2253,7 @@ packages:
|
||||
framer-motion: 10.18.0(react-dom@18.3.1)(react@18.3.1)
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
react-remove-scroll: 2.6.0(@types/react@18.3.3)(react@18.3.1)
|
||||
react-remove-scroll: 2.5.10(@types/react@18.3.3)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
@@ -3574,8 +3574,8 @@ packages:
|
||||
prettier: 3.3.3
|
||||
dev: true
|
||||
|
||||
/@invoke-ai/ui-library@0.0.33(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-YLydTCOTUEgju4Ex6yXt/bvNBcO97y6zc1cGYjt7vtJMS8e6deA89cC5JejjbmVgntdnn49cDyeUxB8Z24gZew==}
|
||||
/@invoke-ai/ui-library@0.0.32(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-JxAoblrDu/cZ4ha9KO4ry5OWvyLUE1Dj28i+ciMaDNUpC/cN+IyiTbUBoFoPaoN5JP8Zpd/MYCcmF2qsziHDzg==}
|
||||
peerDependencies:
|
||||
'@fontsource-variable/inter': ^5.0.16
|
||||
react: ^18.2.0
|
||||
@@ -10011,8 +10011,8 @@ packages:
|
||||
resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==}
|
||||
dev: false
|
||||
|
||||
/react-focus-lock@2.13.2(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-T/7bsofxYqnod2xadvuwjGKHOoL5GH7/EIPI5UyEvaU/c2CcphvGI371opFtuY/SYdbMsNiuF4HsHQ50nA/TKQ==}
|
||||
/react-focus-lock@2.13.0(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-w7aIcTwZwNzUp2fYQDMICy+khFwVmKmOrLF8kNsPS+dz4Oi/oxoVJ2wCMVvX6rWGriM/+mYaTyp1MRmkcs2amw==}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
@@ -10147,6 +10147,25 @@ packages:
|
||||
tslib: 2.7.0
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.5.10(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-m3zvBRANPBw3qxVVjEIPEQinkcwlFZ4qyomuWVpNJdv4c6MvHfXV0C3L9Jx5rr3HeBHKNRX+1jreB5QloDIJjA==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
react-remove-scroll-bar: 2.3.6(@types/react@18.3.3)(react@18.3.1)
|
||||
react-style-singleton: 2.2.1(@types/react@18.3.3)(react@18.3.1)
|
||||
tslib: 2.7.0
|
||||
use-callback-ref: 1.3.2(@types/react@18.3.3)(react@18.3.1)
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.5.5(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==}
|
||||
engines: {node: '>=10'}
|
||||
@@ -10166,27 +10185,8 @@ packages:
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.6.0(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-I2U4JVEsQenxDAKaVa3VZ/JeJZe0/2DxPWL8Tj8yLKctQJQiZM52pn/GWFpSp8dftjM3pSAHVJZscAnC/y+ySQ==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
react-remove-scroll-bar: 2.3.6(@types/react@18.3.3)(react@18.3.1)
|
||||
react-style-singleton: 2.2.1(@types/react@18.3.3)(react@18.3.1)
|
||||
tslib: 2.7.0
|
||||
use-callback-ref: 1.3.2(@types/react@18.3.3)(react@18.3.1)
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-resizable-panels@2.1.2(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-Ku2Bo7JvE8RpHhl4X1uhkdeT9auPBoxAOlGTqomDUUrBAX2mVGuHYZTcWvlnJSgx0QyHIxHECgGB5XVPUbUOkQ==}
|
||||
/react-resizable-panels@2.0.23(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-8ZKTwTU11t/FYwiwhMdtZYYyFxic5U5ysRu2YwfkAgDbUJXFvnWSJqhnzkSlW+mnDoNAzDCrJhdOSXBPA76wug==}
|
||||
peerDependencies:
|
||||
react: ^16.14.0 || ^17.0.0 || ^18.0.0
|
||||
react-dom: ^16.14.0 || ^17.0.0 || ^18.0.0
|
||||
|
||||
@@ -93,7 +93,6 @@
|
||||
"copy": "Copy",
|
||||
"copyError": "$t(gallery.copy) Error",
|
||||
"on": "On",
|
||||
"off": "Off",
|
||||
"or": "or",
|
||||
"checkpoint": "Checkpoint",
|
||||
"communityLabel": "Community",
|
||||
@@ -375,7 +374,6 @@
|
||||
"useCache": "Use Cache"
|
||||
},
|
||||
"gallery": {
|
||||
"gallery": "Gallery",
|
||||
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
|
||||
"assets": "Assets",
|
||||
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
|
||||
@@ -388,11 +386,11 @@
|
||||
"deleteImage_one": "Delete Image",
|
||||
"deleteImage_other": "Delete {{count}} Images",
|
||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||
"displayBoardSearch": "Board Search",
|
||||
"displaySearch": "Image Search",
|
||||
"displayBoardSearch": "Display Board Search",
|
||||
"displaySearch": "Display Search",
|
||||
"download": "Download",
|
||||
"exitBoardSearch": "Exit Board Search",
|
||||
"exitSearch": "Exit Image Search",
|
||||
"exitSearch": "Exit Search",
|
||||
"featuresWillReset": "If you delete this image, those features will immediately be reset.",
|
||||
"galleryImageSize": "Image Size",
|
||||
"gallerySettings": "Gallery Settings",
|
||||
@@ -438,8 +436,7 @@
|
||||
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
|
||||
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
|
||||
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
|
||||
"toggleMiniViewer": "Toggle Mini Viewer"
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Search Hotkeys",
|
||||
@@ -1051,8 +1048,8 @@
|
||||
"scaledHeight": "Scaled H",
|
||||
"scaledWidth": "Scaled W",
|
||||
"scheduler": "Scheduler",
|
||||
"seamlessXAxis": "Seamless X Axis",
|
||||
"seamlessYAxis": "Seamless Y Axis",
|
||||
"seamlessXAxis": "Seamless Tiling X Axis",
|
||||
"seamlessYAxis": "Seamless Tiling Y Axis",
|
||||
"seed": "Seed",
|
||||
"imageActions": "Image Actions",
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
@@ -1716,8 +1713,6 @@
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"sendingToCanvas": "Sending to Canvas",
|
||||
"sendingToGallery": "Sending to Gallery",
|
||||
"sendToGallery": "Send To Gallery",
|
||||
"sendToGalleryDesc": "Generations will be sent to the gallery.",
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
@@ -1797,19 +1792,9 @@
|
||||
"filter": "Filter",
|
||||
"filters": "Filters",
|
||||
"filterType": "Filter Type",
|
||||
"autoProcess": "Auto Process",
|
||||
"reset": "Reset",
|
||||
"process": "Process",
|
||||
"preview": "Preview",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel",
|
||||
"spandrel": {
|
||||
"label": "Image-to-Image Model",
|
||||
"description": "Run an image-to-image model on the selected layer.",
|
||||
"paramModel": "Model",
|
||||
"paramAutoScale": "Auto Scale",
|
||||
"paramAutoScaleDesc": "The selected model will be run until the target scale is reached.",
|
||||
"paramScale": "Target Scale"
|
||||
}
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"transform": {
|
||||
"transform": "Transform",
|
||||
@@ -1817,28 +1802,6 @@
|
||||
"reset": "Reset",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"settings": {
|
||||
"snapToGrid": {
|
||||
"label": "Snap to Grid",
|
||||
"on": "On",
|
||||
"off": "Off"
|
||||
}
|
||||
},
|
||||
"HUD": {
|
||||
"bbox": "Bbox",
|
||||
"scaledBbox": "Scaled Bbox",
|
||||
"autoSave": "Auto Save",
|
||||
"entityStatus": {
|
||||
"selectedEntity": "Selected Entity",
|
||||
"selectedEntityIs": "Selected Entity is",
|
||||
"isFiltering": "is filtering",
|
||||
"isTransforming": "is transforming",
|
||||
"isLocked": "is locked",
|
||||
"isHidden": "is hidden",
|
||||
"isDisabled": "is disabled",
|
||||
"enabled": "Enabled"
|
||||
}
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
|
||||
@@ -21,16 +21,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
direction,
|
||||
shadows: {
|
||||
..._theme.shadows,
|
||||
selected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverSelected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverUnselected:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
selectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
|
||||
hoverSelectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
|
||||
},
|
||||
});
|
||||
}, [direction]);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { TypedStartListening } from '@reduxjs/toolkit';
|
||||
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
@@ -9,7 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
@@ -41,8 +40,6 @@ export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
const startAppListening = listenerMiddleware.startListening as AppStartListening;
|
||||
|
||||
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||
*
|
||||
@@ -122,5 +119,3 @@ addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
// addControlAdapterPreprocessor(startAppListening);
|
||||
|
||||
addCancellationsListeners(startAppListening);
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import {
|
||||
sessionStagingAreaImageAccepted,
|
||||
sessionStagingAreaReset,
|
||||
} from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { canvasReset, rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { stagingAreaImageAccepted, stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
const matchCanvasOrStagingAreaRest = isAnyOf(stagingAreaReset, canvasReset);
|
||||
const matchCanvasOrStagingAreaRest = isAnyOf(sessionStagingAreaReset, canvasReset);
|
||||
|
||||
export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -29,6 +33,8 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
const { canceled } = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
|
||||
if (canceled > 0) {
|
||||
log.debug(`Canceled ${canceled} canvas batches`);
|
||||
toast({
|
||||
@@ -49,11 +55,11 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: stagingAreaImageAccepted,
|
||||
actionCreator: sessionStagingAreaImageAccepted,
|
||||
effect: (action, api) => {
|
||||
const { index } = action.payload;
|
||||
const state = api.getState();
|
||||
const stagingAreaImage = state.canvasStagingArea.stagedImages[index];
|
||||
const stagingAreaImage = state.canvasSession.stagedImages[index];
|
||||
|
||||
assert(stagingAreaImage, 'No staged image found to accept');
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
@@ -66,7 +72,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
};
|
||||
|
||||
api.dispatch(rasterLayerAdded({ overrides, isSelected: false }));
|
||||
api.dispatch(stagingAreaReset());
|
||||
api.dispatch(sessionStagingAreaReset());
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
/**
|
||||
* To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
|
||||
* cancellations:
|
||||
* - In the route handlers above, we track and update the cancellations object
|
||||
* - When the user queues a, we should reset the cancellations, also handled int he route handlers above
|
||||
* - When we get a progress event, we should check if the event is cancelled before setting the event
|
||||
*
|
||||
* We have a few ways that cancellations are effected, so we need to track them all:
|
||||
* - by queue item id (in this case, we will compare the session_id and not the item_id)
|
||||
* - by batch id
|
||||
* - by destination
|
||||
* - by clearing the queue
|
||||
*/
|
||||
type Cancellations = {
|
||||
sessionIds: Set<string>;
|
||||
batchIds: Set<string>;
|
||||
destinations: Set<string>;
|
||||
clearQueue: boolean;
|
||||
};
|
||||
|
||||
const resetCancellations = (): void => {
|
||||
cancellations.clearQueue = false;
|
||||
cancellations.sessionIds.clear();
|
||||
cancellations.batchIds.clear();
|
||||
cancellations.destinations.clear();
|
||||
};
|
||||
|
||||
const cancellations: Cancellations = {
|
||||
sessionIds: new Set(),
|
||||
batchIds: new Set(),
|
||||
destinations: new Set(),
|
||||
clearQueue: false,
|
||||
} as Readonly<Cancellations>;
|
||||
|
||||
/**
|
||||
* Checks if an item is cancelled, used to prevent race conditions with event handling.
|
||||
*
|
||||
* To use this, provide the session_id, batch_id and destination from the event payload.
|
||||
*/
|
||||
export const getIsCancelled = (item: {
|
||||
session_id: string;
|
||||
batch_id: string;
|
||||
destination?: string | null;
|
||||
}): boolean => {
|
||||
if (cancellations.clearQueue) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.sessionIds.has(item.session_id)) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.batchIds.has(item.batch_id)) {
|
||||
return true;
|
||||
}
|
||||
if (item.destination && cancellations.destinations.has(item.destination)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
export const addCancellationsListeners = (startAppListening: AppStartListening) => {
|
||||
// When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
|
||||
// Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||
effect: () => {
|
||||
resetCancellations();
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.destinations.add(action.meta.arg.originalArgs.destination);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.sessionIds.add(action.payload.session_id);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
|
||||
effect: (action) => {
|
||||
for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
|
||||
cancellations.batchIds.add(batch_id);
|
||||
}
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
|
||||
effect: () => {
|
||||
cancellations.clearQueue = true;
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -4,12 +4,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import type { Result } from 'common/util/result';
|
||||
import { isErr, withResult, withResultAsync } from 'common/util/result';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
selectIsStaging,
|
||||
stagingAreaReset,
|
||||
stagingAreaStartedStaging,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
@@ -35,14 +31,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
let didStartStaging = false;
|
||||
|
||||
if (!selectIsStaging(state) && state.canvasSettings.sendToCanvas) {
|
||||
dispatch(stagingAreaStartedStaging());
|
||||
if (!state.canvasSession.isStaging && state.canvasSettings.sendToCanvas) {
|
||||
dispatch(sessionStartedStaging());
|
||||
didStartStaging = true;
|
||||
}
|
||||
|
||||
const abortStaging = () => {
|
||||
if (didStartStaging && selectIsStaging(getState())) {
|
||||
dispatch(stagingAreaReset());
|
||||
if (didStartStaging && getState().canvasSession.isStaging) {
|
||||
dispatch(sessionStagingAreaReset());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
@@ -10,6 +11,7 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { shouldShowProgressInViewer } = state.ui;
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
|
||||
@@ -23,6 +25,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
);
|
||||
try {
|
||||
await req.unwrap();
|
||||
if (shouldShowProgressInViewer) {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
//TODO(psyche): handle image deletion (canvas staging area?)
|
||||
//TODO(psyche): handle image deletion (canvas sessions?)
|
||||
|
||||
// Some utils to delete images from different parts of the app
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
ipaImageChanged,
|
||||
@@ -13,7 +12,7 @@ import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/c
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imageToCompareChanged, isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -104,14 +103,11 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
const defaultControlAdapter = selectDefaultControlAdapter(state);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
controlAdapter: defaultControlAdapter,
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
@@ -146,6 +142,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -45,8 +44,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
if (!autoAddBoardId || autoAddBoardId === 'none') {
|
||||
const title = postUploadAction.title || DEFAULT_UPLOADED_TOAST.title;
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, title });
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
} else {
|
||||
// Add this image to the board
|
||||
dispatch(
|
||||
@@ -70,8 +67,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description,
|
||||
});
|
||||
dispatch(boardIdSelected({ boardId: autoAddBoardId }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6,12 +6,9 @@ import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
canvasStagingAreaPersistConfig,
|
||||
canvasStagingAreaSlice,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
|
||||
@@ -66,7 +63,7 @@ const allReducers = {
|
||||
[stylePresetSlice.name]: stylePresetSlice.reducer,
|
||||
[paramsSlice.name]: paramsSlice.reducer,
|
||||
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
|
||||
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
|
||||
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
};
|
||||
|
||||
@@ -111,7 +108,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
|
||||
[paramsPersistConfig.name]: paramsPersistConfig,
|
||||
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
|
||||
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
|
||||
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
};
|
||||
|
||||
|
||||
@@ -6,51 +6,18 @@ import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
|
||||
import IAIDraggable from './IAIDraggable';
|
||||
import IAIDroppable from './IAIDroppable';
|
||||
import SelectionOverlay from './SelectionOverlay';
|
||||
|
||||
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
|
||||
|
||||
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
'.gallery-image-container::before': {
|
||||
content: '""',
|
||||
display: 'inline-block',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
pointerEvents: 'none',
|
||||
borderRadius: 'base',
|
||||
},
|
||||
'&[data-selected="selected"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&[data-selected="selectedForCompare"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
'&:hover>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected="selected"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected="selectedForCompare"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
};
|
||||
|
||||
type IAIDndImageProps = FlexProps & {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
@@ -108,11 +75,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const handleMouseOver = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOver) {
|
||||
onMouseOver(e);
|
||||
}
|
||||
setIsHovered(true);
|
||||
},
|
||||
[onMouseOver]
|
||||
);
|
||||
@@ -121,6 +90,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
if (onMouseOut) {
|
||||
onMouseOut(e);
|
||||
}
|
||||
setIsHovered(false);
|
||||
},
|
||||
[onMouseOut]
|
||||
);
|
||||
@@ -171,13 +141,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
minH={minSize ? minSize : undefined}
|
||||
userSelect="none"
|
||||
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
|
||||
sx={withHoverOverlay ? sx : undefined}
|
||||
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
|
||||
{...rest}
|
||||
>
|
||||
{imageDTO && (
|
||||
<Flex
|
||||
className="gallery-image-container"
|
||||
w="full"
|
||||
h="full"
|
||||
position={fitContainer ? 'absolute' : 'relative'}
|
||||
@@ -200,6 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
data-testid={dataTestId}
|
||||
/>
|
||||
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
|
||||
<SelectionOverlay
|
||||
isSelected={isSelected}
|
||||
isSelectedForCompare={isSelectedForCompare}
|
||||
isHovered={withHoverOverlay ? isHovered : false}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{!imageDTO && !isUploadDisabled && (
|
||||
|
||||
104
invokeai/frontend/web/src/common/components/IconSwitch.tsx
Normal file
104
invokeai/frontend/web/src/common/components/IconSwitch.tsx
Normal file
@@ -0,0 +1,104 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, IconButton, Tooltip, useToken } from '@invoke-ai/ui-library';
|
||||
import type { ReactElement, ReactNode } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type IconSwitchProps = {
|
||||
isChecked: boolean;
|
||||
onChange: (checked: boolean) => void;
|
||||
iconChecked: ReactElement;
|
||||
tooltipChecked?: ReactNode;
|
||||
iconUnchecked: ReactElement;
|
||||
tooltipUnchecked?: ReactNode;
|
||||
ariaLabel: string;
|
||||
};
|
||||
|
||||
const getSx = (padding: string | number): SystemStyleObject => ({
|
||||
transition: 'left 0.1s ease-in-out, transform 0.1s ease-in-out',
|
||||
'&[data-checked="true"]': {
|
||||
left: `calc(100% - ${padding})`,
|
||||
transform: 'translateX(-100%)',
|
||||
},
|
||||
'&[data-checked="false"]': {
|
||||
left: padding,
|
||||
transform: 'translateX(0)',
|
||||
},
|
||||
});
|
||||
|
||||
export const IconSwitch = memo(
|
||||
({
|
||||
isChecked,
|
||||
onChange,
|
||||
iconChecked,
|
||||
tooltipChecked,
|
||||
iconUnchecked,
|
||||
tooltipUnchecked,
|
||||
ariaLabel,
|
||||
}: IconSwitchProps) => {
|
||||
const onUncheck = useCallback(() => {
|
||||
onChange(false);
|
||||
}, [onChange]);
|
||||
const onCheck = useCallback(() => {
|
||||
onChange(true);
|
||||
}, [onChange]);
|
||||
|
||||
const gap = useToken('space', 1.5);
|
||||
const sx = useMemo(() => getSx(gap), [gap]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
h="full"
|
||||
p={gap}
|
||||
gap={gap}
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
borderRadius="base"
|
||||
bg="invokeBlue.400"
|
||||
w={12}
|
||||
top={gap}
|
||||
bottom={gap}
|
||||
data-checked={isChecked}
|
||||
sx={sx}
|
||||
/>
|
||||
<Tooltip hasArrow label={tooltipUnchecked}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
fontSize={16}
|
||||
icon={iconUnchecked}
|
||||
onClick={onUncheck}
|
||||
variant={!isChecked ? 'solid' : 'ghost'}
|
||||
colorScheme={!isChecked ? 'invokeBlue' : 'base'}
|
||||
aria-label={ariaLabel}
|
||||
data-checked={!isChecked}
|
||||
w={12}
|
||||
alignSelf="stretch"
|
||||
h="auto"
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip hasArrow label={tooltipChecked}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
fontSize={16}
|
||||
icon={iconChecked}
|
||||
onClick={onCheck}
|
||||
variant={isChecked ? 'solid' : 'ghost'}
|
||||
colorScheme={isChecked ? 'invokeBlue' : 'base'}
|
||||
aria-label={ariaLabel}
|
||||
data-checked={isChecked}
|
||||
w={12}
|
||||
alignSelf="stretch"
|
||||
h="auto"
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IconSwitch.displayName = 'IconSwitch';
|
||||
@@ -0,0 +1,46 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
isSelected: boolean;
|
||||
isSelectedForCompare: boolean;
|
||||
isHovered: boolean;
|
||||
};
|
||||
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
|
||||
const shadow = useMemo(() => {
|
||||
if (isSelectedForCompare && isHovered) {
|
||||
return 'hoverSelectedForCompare';
|
||||
}
|
||||
if (isSelectedForCompare && !isHovered) {
|
||||
return 'selectedForCompare';
|
||||
}
|
||||
if (isSelected && isHovered) {
|
||||
return 'hoverSelected';
|
||||
}
|
||||
if (isSelected && !isHovered) {
|
||||
return 'selected';
|
||||
}
|
||||
if (!isSelected && isHovered) {
|
||||
return 'hoverUnselected';
|
||||
}
|
||||
return undefined;
|
||||
}, [isHovered, isSelected, isSelectedForCompare]);
|
||||
return (
|
||||
<Box
|
||||
className="selection-box"
|
||||
position="absolute"
|
||||
top={0}
|
||||
insetInlineEnd={0}
|
||||
bottom={0}
|
||||
insetInlineStart={0}
|
||||
borderRadius="base"
|
||||
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
pointerEvents="none"
|
||||
shadow={shadow}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SelectionOverlay);
|
||||
@@ -114,13 +114,4 @@ export const useGlobalHotkeys = () => {
|
||||
},
|
||||
[dispatch, isModelManagerEnabled]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
isModelManagerEnabled ? '6' : '5',
|
||||
() => {
|
||||
dispatch(setActiveTab('gallery'));
|
||||
setScopes([]);
|
||||
},
|
||||
[dispatch, isModelManagerEnabled]
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,22 +1,34 @@
|
||||
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
useAddControlLayer,
|
||||
useAddInpaintMask,
|
||||
useAddIPAdapter,
|
||||
useAddRasterLayer,
|
||||
useAddRegionalGuidance,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { memo } from 'react';
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasAddEntityButtons = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const addInpaintMask = useAddInpaintMask();
|
||||
const addRegionalGuidance = useAddRegionalGuidance();
|
||||
const addRasterLayer = useAddRasterLayer();
|
||||
const addControlLayer = useAddControlLayer();
|
||||
const addIPAdapter = useAddIPAdapter();
|
||||
const dispatch = useAppDispatch();
|
||||
const addInpaintMask = useCallback(() => {
|
||||
dispatch(inpaintMaskAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addRegionalGuidance = useCallback(() => {
|
||||
dispatch(rgAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addRasterLayer = useCallback(() => {
|
||||
dispatch(rasterLayerAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addControlLayer = useCallback(() => {
|
||||
dispatch(controlLayerAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
dispatch(ipaAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" h="full" alignItems="center">
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import type { AddControlLayerFromImageDropData, AddRasterLayerFromImageDropData } from 'features/dnd/types';
|
||||
import { useIsImageViewerOpen } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { memo } from 'react';
|
||||
|
||||
const addRasterLayerFromImageDropData: AddRasterLayerFromImageDropData = {
|
||||
@@ -14,6 +15,12 @@ const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
|
||||
};
|
||||
|
||||
export const CanvasDropArea = memo(() => {
|
||||
const isImageViewerOpen = useIsImageViewerOpen();
|
||||
|
||||
if (isImageViewerOpen) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex position="absolute" top={0} right={0} bottom="50%" left={0} gap={2} pointerEvents="none">
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import type { Meta, StoryObj } from '@storybook/react';
|
||||
import { CanvasEditor } from 'features/controlLayers/components/CanvasEditor';
|
||||
|
||||
const meta: Meta<typeof CanvasEditor> = {
|
||||
title: 'Feature/ControlLayers',
|
||||
tags: ['autodocs'],
|
||||
component: CanvasEditor,
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof CanvasEditor>;
|
||||
|
||||
const Component = () => {
|
||||
return (
|
||||
<Flex w={1500} h={1500}>
|
||||
<CanvasEditor />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export const Default: Story = {
|
||||
render: Component,
|
||||
};
|
||||
@@ -11,7 +11,7 @@ import { Transform } from 'features/controlLayers/components/Transform/Transform
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { memo, useRef } from 'react';
|
||||
|
||||
export const CanvasTabContent = memo(() => {
|
||||
export const CanvasEditor = memo(() => {
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useScopeOnFocus('canvas', ref);
|
||||
|
||||
@@ -48,4 +48,4 @@ export const CanvasTabContent = memo(() => {
|
||||
);
|
||||
});
|
||||
|
||||
CanvasTabContent.displayName = 'CanvasTabContent';
|
||||
CanvasEditor.displayName = 'CanvasEditor';
|
||||
@@ -0,0 +1,20 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { EntityListGlobalActionBarAddLayerMenu } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu';
|
||||
import { EntityListGlobalActionBarDenoisingStrength } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarDenoisingStrength';
|
||||
import { EntityListGlobalActionBarFitBboxToLayers } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarFitBboxToLayers';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const EntityListGlobalActionBar = memo(() => {
|
||||
return (
|
||||
<Flex w="full" py={1} px={1} gap={2} alignItems="center">
|
||||
<EntityListGlobalActionBarDenoisingStrength />
|
||||
<Spacer />
|
||||
<Flex>
|
||||
<EntityListGlobalActionBarFitBboxToLayers />
|
||||
<EntityListGlobalActionBarAddLayerMenu />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
EntityListGlobalActionBar.displayName = 'EntityListGlobalActionBar';
|
||||
@@ -1,22 +1,37 @@
|
||||
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useDefaultIPAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
|
||||
import {
|
||||
useAddControlLayer,
|
||||
useAddInpaintMask,
|
||||
useAddIPAdapter,
|
||||
useAddRasterLayer,
|
||||
useAddRegionalGuidance,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { memo } from 'react';
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const addInpaintMask = useAddInpaintMask();
|
||||
const addRegionalGuidance = useAddRegionalGuidance();
|
||||
const addRasterLayer = useAddRasterLayer();
|
||||
const addControlLayer = useAddControlLayer();
|
||||
const addIPAdapter = useAddIPAdapter();
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useDefaultIPAdapter();
|
||||
const addInpaintMask = useCallback(() => {
|
||||
dispatch(inpaintMaskAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addRegionalGuidance = useCallback(() => {
|
||||
dispatch(rgAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addRasterLayer = useCallback(() => {
|
||||
dispatch(rasterLayerAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addControlLayer = useCallback(() => {
|
||||
dispatch(controlLayerAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
const overrides = { ipAdapter: defaultIPAdapter };
|
||||
dispatch(ipaAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import {
|
||||
CompositeSlider,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
IconButton,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
Popover,
|
||||
PopoverAnchor,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import type { KeyboardEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
|
||||
const marks = [0, 0.25, 0.5, 0.75, 1];
|
||||
|
||||
export const EntityListGlobalActionBarDenoisingStrength = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const strength = useAppSelector(selectImg2imgStrength);
|
||||
const config = useAppSelector(selectImg2imgStrengthConfig);
|
||||
|
||||
const [localStrength, setLocalStrength] = useState(strength);
|
||||
|
||||
const onChangeSlider = useCallback(
|
||||
(value: number) => {
|
||||
dispatch(setImg2imgStrength(value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onBlur = useCallback(() => {
|
||||
if (isNaN(Number(localStrength))) {
|
||||
setLocalStrength(config.initial);
|
||||
return;
|
||||
}
|
||||
dispatch(setImg2imgStrength(clamp(localStrength, 0, 1)));
|
||||
}, [config.initial, dispatch, localStrength]);
|
||||
|
||||
const onChangeNumberInput = useCallback((valueAsString: string, valueAsNumber: number) => {
|
||||
setLocalStrength(valueAsNumber);
|
||||
}, []);
|
||||
|
||||
const onKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
onBlur();
|
||||
}
|
||||
},
|
||||
[onBlur]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setLocalStrength(strength);
|
||||
}, [strength]);
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<FormControl w="min-content" gap={2}>
|
||||
<InformationalPopover feature="paramDenoisingStrength">
|
||||
<FormLabel m={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
|
||||
</InformationalPopover>
|
||||
<PopoverAnchor>
|
||||
<NumberInput
|
||||
display="flex"
|
||||
alignItems="center"
|
||||
step={config.coarseStep}
|
||||
min={config.numberInputMin}
|
||||
max={config.numberInputMax}
|
||||
defaultValue={config.initial}
|
||||
value={localStrength}
|
||||
onChange={onChangeNumberInput}
|
||||
onBlur={onBlur}
|
||||
w="60px"
|
||||
onKeyDown={onKeyDown}
|
||||
clampValueOnBlur={false}
|
||||
variant="outline"
|
||||
>
|
||||
<NumberInputField paddingInlineEnd={7} _focusVisible={{ zIndex: 0 }} />
|
||||
<PopoverTrigger>
|
||||
<IconButton
|
||||
aria-label="open-slider"
|
||||
icon={<PiCaretDownBold />}
|
||||
size="sm"
|
||||
variant="link"
|
||||
position="absolute"
|
||||
insetInlineEnd={0}
|
||||
h="full"
|
||||
/>
|
||||
</PopoverTrigger>
|
||||
</NumberInput>
|
||||
</PopoverAnchor>
|
||||
</FormControl>
|
||||
<PopoverContent w={200} pt={0} pb={2} px={4}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<CompositeSlider
|
||||
step={config.coarseStep}
|
||||
fineStep={config.fineStep}
|
||||
min={config.sliderMin}
|
||||
max={config.sliderMax}
|
||||
defaultValue={config.initial}
|
||||
onChange={onChangeSlider}
|
||||
value={localStrength}
|
||||
marks={marks}
|
||||
alwaysShowMarks
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
EntityListGlobalActionBarDenoisingStrength.displayName = 'EntityListGlobalActionBarDenoisingStrength';
|
||||
@@ -4,7 +4,7 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsOut } from 'react-icons/pi';
|
||||
|
||||
export const CanvasToolbarFitBboxToLayersButton = memo(() => {
|
||||
export const EntityListGlobalActionBarFitBboxToLayers = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const canvasManager = useCanvasManager();
|
||||
const onClick = useCallback(() => {
|
||||
@@ -14,7 +14,9 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
|
||||
return (
|
||||
<IconButton
|
||||
onClick={onClick}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label={t('controlLayers.fitBboxToLayers')}
|
||||
tooltip={t('controlLayers.fitBboxToLayers')}
|
||||
icon={<PiArrowsOut />}
|
||||
@@ -22,4 +24,4 @@ export const CanvasToolbarFitBboxToLayersButton = memo(() => {
|
||||
);
|
||||
});
|
||||
|
||||
CanvasToolbarFitBboxToLayersButton.displayName = 'CanvasToolbarFitBboxToLayersButton';
|
||||
EntityListGlobalActionBarFitBboxToLayers.displayName = 'EntityListGlobalActionBarFitBboxToLayers';
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { EntityListGlobalActionBarAddLayerMenu } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu';
|
||||
import { EntityListSelectedEntityActionBarDuplicateButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarDuplicateButton';
|
||||
import { EntityListSelectedEntityActionBarFill } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFill';
|
||||
import { EntityListSelectedEntityActionBarFilterButton } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBarFilterButton';
|
||||
@@ -9,15 +8,14 @@ import { memo } from 'react';
|
||||
|
||||
export const EntityListSelectedEntityActionBar = memo(() => {
|
||||
return (
|
||||
<Flex w="full" gap={2} alignItems="center" ps={1}>
|
||||
<Flex w="full" py={1} px={1} gap={2} alignItems="center">
|
||||
<EntityListSelectedEntityActionBarOpacity />
|
||||
<Spacer />
|
||||
<EntityListSelectedEntityActionBarFill />
|
||||
<Flex h="full">
|
||||
<Flex>
|
||||
<EntityListSelectedEntityActionBarFilterButton />
|
||||
<EntityListSelectedEntityActionBarTransformButton />
|
||||
<EntityListSelectedEntityActionBarDuplicateButton />
|
||||
<EntityListGlobalActionBarAddLayerMenu />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityFilter } from 'features/controlLayers/hooks/useEntityFilter';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiShootingStarBold } from 'react-icons/pi';
|
||||
|
||||
export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const filter = useEntityFilter(selectedEntityIdentifier);
|
||||
const canvasManager = useCanvasManager();
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
if (!selectedEntityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isFilterableEntityIdentifier(selectedEntityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
|
||||
canvasManager.filter.startFilter(selectedEntityIdentifier);
|
||||
}, [canvasManager, selectedEntityIdentifier]);
|
||||
|
||||
if (!selectedEntityIdentifier) {
|
||||
return null;
|
||||
@@ -22,8 +37,8 @@ export const EntityListSelectedEntityActionBarFilterButton = memo(() => {
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={filter.start}
|
||||
isDisabled={filter.isDisabled}
|
||||
onClick={onClick}
|
||||
isDisabled={isBusy || isStaging}
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
|
||||
@@ -1,16 +1,34 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityTransform } from 'features/controlLayers/hooks/useEntityTransform';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFrameCornersBold } from 'react-icons/pi';
|
||||
|
||||
export const EntityListSelectedEntityActionBarTransformButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const transform = useEntityTransform(selectedEntityIdentifier);
|
||||
const canvasManager = useCanvasManager();
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
if (!selectedEntityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isTransformableEntityIdentifier(selectedEntityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
const adapter = canvasManager.getAdapter(selectedEntityIdentifier);
|
||||
if (!adapter) {
|
||||
return;
|
||||
}
|
||||
adapter.transformer.startTransform();
|
||||
}, [canvasManager, selectedEntityIdentifier]);
|
||||
|
||||
if (!selectedEntityIdentifier) {
|
||||
return null;
|
||||
@@ -22,8 +40,8 @@ export const EntityListSelectedEntityActionBarTransformButton = memo(() => {
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={transform.start}
|
||||
isDisabled={transform.isDisabled}
|
||||
onClick={onClick}
|
||||
isDisabled={isBusy || isStaging}
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Divider, Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasAddEntityButtons } from 'features/controlLayers/components/CanvasAddEntityButtons';
|
||||
import { CanvasEntityList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityList';
|
||||
import { EntityListGlobalActionBar } from 'features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBar';
|
||||
import { EntityListSelectedEntityActionBar } from 'features/controlLayers/components/CanvasEntityList/EntityListSelectedEntityActionBar';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectHasEntities } from 'features/controlLayers/store/selectors';
|
||||
@@ -13,6 +14,8 @@ export const CanvasPanelContent = memo(() => {
|
||||
return (
|
||||
<CanvasManagerProviderGate>
|
||||
<Flex flexDir="column" gap={2} w="full" h="full">
|
||||
<EntityListGlobalActionBar />
|
||||
<Divider py={0} />
|
||||
<EntityListSelectedEntityActionBar />
|
||||
<Divider py={0} />
|
||||
{!hasEntities && <CanvasAddEntityButtons />}
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
import { useDndContext } from '@dnd-kit/core';
|
||||
import { Box, Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useScopeOnFocus } from 'common/hooks/interactionScopes';
|
||||
import { CanvasPanelContent } from 'features/controlLayers/components/CanvasPanelContent';
|
||||
import { CanvasSendToToggle } from 'features/controlLayers/components/CanvasSendToToggle';
|
||||
import { selectSendToCanvas } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import GalleryPanelContent from 'features/gallery/components/GalleryPanelContent';
|
||||
import { memo, useCallback, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasRightPanelContent = memo(() => {
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
const [tab, setTab] = useState(0);
|
||||
useScopeOnFocus('gallery', ref);
|
||||
|
||||
return (
|
||||
<Tabs index={tab} onChange={setTab} w="full" h="full" display="flex" flexDir="column">
|
||||
<TabList alignItems="center">
|
||||
<PanelTabs setTab={setTab} />
|
||||
<Spacer />
|
||||
<CanvasSendToToggle />
|
||||
</TabList>
|
||||
<TabPanels w="full" h="full">
|
||||
<TabPanel w="full" h="full" p={0} pt={3}>
|
||||
<CanvasPanelContent />
|
||||
</TabPanel>
|
||||
<TabPanel w="full" h="full" p={0} pt={3}>
|
||||
<GalleryPanelContent />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasRightPanelContent.displayName = 'CanvasRightPanelContent';
|
||||
|
||||
const PanelTabs = memo(({ setTab }: { setTab: (val: number) => void }) => {
|
||||
const { t } = useTranslation();
|
||||
const sendToCanvas = useAppSelector(selectSendToCanvas);
|
||||
const tabTimeout = useRef<number | null>(null);
|
||||
const dndCtx = useDndContext();
|
||||
|
||||
const onOnMouseOverLayersTab = useCallback(() => {
|
||||
tabTimeout.current = window.setTimeout(() => {
|
||||
if (dndCtx.active) {
|
||||
setTab(1);
|
||||
}
|
||||
}, 300);
|
||||
}, [dndCtx.active, setTab]);
|
||||
|
||||
const onOnMouseOverGalleryTab = useCallback(() => {
|
||||
tabTimeout.current = window.setTimeout(() => {
|
||||
if (dndCtx.active) {
|
||||
setTab(0);
|
||||
}
|
||||
}, 300);
|
||||
}, [dndCtx.active, setTab]);
|
||||
|
||||
const onMouseOut = useCallback(() => {
|
||||
if (tabTimeout.current) {
|
||||
clearTimeout(tabTimeout.current);
|
||||
}
|
||||
}, []);
|
||||
return (
|
||||
<>
|
||||
<Tab position="relative" onMouseOver={onOnMouseOverLayersTab} onMouseOut={onMouseOut}>
|
||||
{t('controlLayers.layer_other')}
|
||||
{sendToCanvas && (
|
||||
<Box position="absolute" top={2} right={2} h={2} w={2} bg="invokeYellow.300" borderRadius="full" />
|
||||
)}
|
||||
</Tab>
|
||||
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut}>
|
||||
{t('gallery.gallery')}
|
||||
{!sendToCanvas && (
|
||||
<Box position="absolute" top={2} right={2} h={2} w={2} bg="invokeYellow.300" borderRadius="full" />
|
||||
)}
|
||||
</Tab>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
PanelTabs.displayName = 'PanelTabs';
|
||||
@@ -1,76 +1,64 @@
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Icon,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSendToCanvas, settingsSendToCanvasChanged } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { IconSwitch } from 'common/components/IconSwitch';
|
||||
import {
|
||||
selectCanvasSettingsSlice,
|
||||
settingsSendToCanvasChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold, PiCheckBold } from 'react-icons/pi';
|
||||
import { PiImageBold, PiPaintBrushBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasSendToToggle = memo(() => {
|
||||
const TooltipSendToGallery = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const sendToCanvas = useAppSelector(selectSendToCanvas);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const enableSendToCanvas = useCallback(() => {
|
||||
dispatch(settingsSendToCanvasChanged(true));
|
||||
}, [dispatch]);
|
||||
|
||||
const disableSendToCanvas = useCallback(() => {
|
||||
dispatch(settingsSendToCanvasChanged(false));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="link"
|
||||
data-testid="toggle-viewer-menu-button"
|
||||
pointerEvents="auto"
|
||||
rightIcon={<PiCaretDownBold />}
|
||||
>
|
||||
{sendToCanvas ? t('controlLayers.sendingToCanvas') : t('controlLayers.sendingToGallery')}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent p={2} pointerEvents="auto">
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<Flex flexDir="column">
|
||||
<Button onClick={disableSendToCanvas} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Flex gap={2} w="full">
|
||||
<Icon as={PiCheckBold} visibility={!sendToCanvas ? 'visible' : 'hidden'} />
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToGallery')}</Text>
|
||||
<Text fontWeight="normal" variant="subtext">
|
||||
{t('controlLayers.sendToGalleryDesc')}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Button>
|
||||
<Button onClick={enableSendToCanvas} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Flex gap={2} w="full">
|
||||
<Icon as={PiCheckBold} visibility={sendToCanvas ? 'visible' : 'hidden'} />
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToCanvas')}</Text>
|
||||
<Text fontWeight="normal" variant="subtext">
|
||||
{t('controlLayers.sendToCanvasDesc')}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Button>
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToGallery')}</Text>
|
||||
<Text fontWeight="normal">{t('controlLayers.sendToGalleryDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipSendToGallery.displayName = 'TooltipSendToGallery';
|
||||
|
||||
const TooltipSendToCanvas = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToCanvas')}</Text>
|
||||
<Text fontWeight="normal">{t('controlLayers.sendToCanvasDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipSendToCanvas.displayName = 'TooltipSendToCanvas';
|
||||
|
||||
const selectSendToCanvas = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.sendToCanvas);
|
||||
|
||||
export const CanvasSendToToggle = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const sendToCanvas = useAppSelector(selectSendToCanvas);
|
||||
|
||||
const onChange = useCallback(
|
||||
(isChecked: boolean) => {
|
||||
dispatch(settingsSendToCanvasChanged(isChecked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IconSwitch
|
||||
isChecked={sendToCanvas}
|
||||
onChange={onChange}
|
||||
iconUnchecked={<PiImageBold />}
|
||||
tooltipUnchecked={<TooltipSendToGallery />}
|
||||
iconChecked={<PiPaintBrushBold />}
|
||||
tooltipChecked={<TooltipSendToCanvas />}
|
||||
ariaLabel="Toggle canvas mode"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,35 +1,21 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
|
||||
import { Weight } from 'features/controlLayers/components/common/Weight';
|
||||
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
|
||||
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useControlLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
|
||||
import {
|
||||
controlLayerBeginEndStepPctChanged,
|
||||
controlLayerControlModeChanged,
|
||||
controlLayerModelChanged,
|
||||
controlLayerWeightChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { ControlModeV2 } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => {
|
||||
const selectControlAdapter = useMemo(
|
||||
() =>
|
||||
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
|
||||
const layer = selectEntityOrThrow(canvas, entityIdentifier);
|
||||
return layer.controlAdapter;
|
||||
}),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const controlAdapter = useAppSelector(selectControlAdapter);
|
||||
return controlAdapter;
|
||||
};
|
||||
|
||||
export const ControlLayerControlAdapter = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const entityIdentifier = useEntityIdentifierContext('control_layer');
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
IMAGE_FILTERS,
|
||||
isControlLayerEntityIdentifier,
|
||||
isFilterType,
|
||||
isRasterLayerEntityIdentifier,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -14,6 +22,8 @@ type Props = {
|
||||
|
||||
export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const canvasManager = useCanvasManager();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels();
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
@@ -24,8 +34,35 @@ export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onCha
|
||||
return;
|
||||
}
|
||||
onChangeModel(modelConfig);
|
||||
|
||||
// When we set the model for the first time, we'll set the default filter settings and open the filter popup
|
||||
|
||||
if (modelKey) {
|
||||
// If there is already a model key, this is not the first time we're setting the model
|
||||
return;
|
||||
}
|
||||
|
||||
// Open the filter popup by setting this entity as the filtering entity
|
||||
if (!canvasManager.filter.$adapter.get()) {
|
||||
// Can only filter raster and control layers
|
||||
if (!isRasterLayerEntityIdentifier(entityIdentifier) && !isControlLayerEntityIdentifier(entityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update the filter, preferring the model's default
|
||||
if (isFilterType(modelConfig.default_settings?.preprocessor)) {
|
||||
canvasManager.filter.$config.set(
|
||||
IMAGE_FILTERS[modelConfig.default_settings.preprocessor].buildDefaults(modelConfig.base)
|
||||
);
|
||||
} else {
|
||||
canvasManager.filter.$config.set(IMAGE_FILTERS.canny_image_processor.buildDefaults(modelConfig.base));
|
||||
}
|
||||
|
||||
canvasManager.filter.startFilter(entityIdentifier);
|
||||
canvasManager.filter.previewFilter();
|
||||
}
|
||||
},
|
||||
[onChangeModel]
|
||||
[canvasManager.filter, entityIdentifier, modelKey, onChangeModel]
|
||||
);
|
||||
|
||||
const getIsDisabled = useCallback(
|
||||
|
||||
@@ -1,49 +1,37 @@
|
||||
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Heading, Spacer, Switch } from '@invoke-ai/ui-library';
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
|
||||
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import {
|
||||
selectAutoProcessFilter,
|
||||
settingsAutoProcessFilterToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { type FilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsCounterClockwiseBold, PiCheckBold, PiShootingStarBold, PiXBold } from 'react-icons/pi';
|
||||
import { PiCheckBold, PiShootingStarBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
|
||||
export const Filter = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const config = useStore(adapter.filterer.$filterConfig);
|
||||
const isProcessing = useStore(adapter.filterer.$isProcessing);
|
||||
const hasProcessed = useStore(adapter.filterer.$hasProcessed);
|
||||
const autoProcessFilter = useAppSelector(selectAutoProcessFilter);
|
||||
const canvasManager = useCanvasManager();
|
||||
const config = useStore(canvasManager.filter.$config);
|
||||
const isFiltering = useStore(canvasManager.filter.$isFiltering);
|
||||
const isProcessing = useStore(canvasManager.filter.$isProcessing);
|
||||
|
||||
const onChangeFilterConfig = useCallback(
|
||||
(filterConfig: FilterConfig) => {
|
||||
adapter.filterer.$filterConfig.set(filterConfig);
|
||||
canvasManager.filter.$config.set(filterConfig);
|
||||
},
|
||||
[adapter.filterer.$filterConfig]
|
||||
[canvasManager.filter.$config]
|
||||
);
|
||||
|
||||
const onChangeFilterType = useCallback(
|
||||
(filterType: FilterConfig['type']) => {
|
||||
adapter.filterer.$filterConfig.set(IMAGE_FILTERS[filterType].buildDefaults());
|
||||
canvasManager.filter.$config.set(IMAGE_FILTERS[filterType].buildDefaults());
|
||||
},
|
||||
[adapter.filterer.$filterConfig]
|
||||
[canvasManager.filter.$config]
|
||||
);
|
||||
|
||||
const onChangeAutoProcessFilter = useCallback(() => {
|
||||
dispatch(settingsAutoProcessFilterToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
const isValid = useMemo(() => {
|
||||
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
|
||||
}, [config]);
|
||||
if (!isFiltering) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -58,56 +46,37 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
|
||||
transitionProperty="height"
|
||||
transitionDuration="normal"
|
||||
>
|
||||
<Flex w="full">
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.filter.filter')}
|
||||
</Heading>
|
||||
<Spacer />
|
||||
<FormControl w="min-content">
|
||||
<FormLabel m={0}>{t('controlLayers.filter.autoProcess')}</FormLabel>
|
||||
<Switch size="sm" isChecked={autoProcessFilter} onChange={onChangeAutoProcessFilter} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Heading size="md" color="base.300" userSelect="none">
|
||||
{t('controlLayers.filter.filter')}
|
||||
</Heading>
|
||||
<FilterTypeSelect filterType={config.type} onChange={onChangeFilterType} />
|
||||
<FilterSettings filterConfig={config} onChange={onChangeFilterConfig} />
|
||||
<ButtonGroup isAttached={false} size="sm" w="full">
|
||||
<Button
|
||||
variant="ghost"
|
||||
leftIcon={<PiShootingStarBold />}
|
||||
onClick={adapter.filterer.process}
|
||||
onClick={canvasManager.filter.previewFilter}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.filter.process')}
|
||||
isDisabled={!isValid}
|
||||
loadingText={t('controlLayers.filter.preview')}
|
||||
>
|
||||
{t('controlLayers.filter.process')}
|
||||
{t('controlLayers.filter.preview')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
<Button
|
||||
leftIcon={<PiArrowsCounterClockwiseBold />}
|
||||
onClick={adapter.filterer.reset}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.filter.reset')}
|
||||
variant="ghost"
|
||||
>
|
||||
{t('controlLayers.filter.reset')}
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
leftIcon={<PiCheckBold />}
|
||||
onClick={adapter.filterer.apply}
|
||||
onClick={canvasManager.filter.applyFilter}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.filter.apply')}
|
||||
isDisabled={!isValid || !hasProcessed}
|
||||
>
|
||||
{t('controlLayers.filter.apply')}
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
leftIcon={<PiXBold />}
|
||||
onClick={adapter.filterer.cancel}
|
||||
onClick={canvasManager.filter.cancelFilter}
|
||||
isLoading={isProcessing}
|
||||
loadingText={t('controlLayers.filter.cancel')}
|
||||
isDisabled={!isValid}
|
||||
>
|
||||
{t('controlLayers.filter.cancel')}
|
||||
</Button>
|
||||
@@ -116,16 +85,4 @@ const FilterBox = memo(({ adapter }: { adapter: CanvasEntityAdapterRasterLayer |
|
||||
);
|
||||
});
|
||||
|
||||
FilterBox.displayName = 'FilterBox';
|
||||
|
||||
export const Filter = () => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const adapter = useStore(canvasManager.stateApi.$filteringAdapter);
|
||||
if (!adapter) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <FilterBox adapter={adapter} />;
|
||||
};
|
||||
|
||||
Filter.displayName = 'Filter';
|
||||
|
||||
@@ -10,7 +10,6 @@ import { FilterMediapipeFace } from 'features/controlLayers/components/Filters/F
|
||||
import { FilterMidasDepth } from 'features/controlLayers/components/Filters/FilterMidasDepth';
|
||||
import { FilterMlsdImage } from 'features/controlLayers/components/Filters/FilterMlsdImage';
|
||||
import { FilterPidi } from 'features/controlLayers/components/Filters/FilterPidi';
|
||||
import { FilterSpandrel } from 'features/controlLayers/components/Filters/FilterSpandrel';
|
||||
import type { FilterConfig } from 'features/controlLayers/store/types';
|
||||
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
|
||||
import { memo } from 'react';
|
||||
@@ -65,10 +64,6 @@ export const FilterSettings = memo(({ filterConfig, onChange }: Props) => {
|
||||
return <FilterPidi config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
if (filterConfig.type === 'spandrel_filter') {
|
||||
return <FilterSpandrel config={filterConfig} onChange={onChange} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<IAINoContentFallback
|
||||
py={4}
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
import {
|
||||
Box,
|
||||
Combobox,
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Switch,
|
||||
Tooltip,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||
import type { SpandrelFilterConfig } from 'features/controlLayers/store/types';
|
||||
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FilterComponentProps } from './types';
|
||||
|
||||
type Props = FilterComponentProps<SpandrelFilterConfig>;
|
||||
const DEFAULTS = IMAGE_FILTERS['spandrel_filter'].buildDefaults();
|
||||
|
||||
export const FilterSpandrel = ({ onChange, config }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||
|
||||
const tooltipLabel = useMemo(() => {
|
||||
if (!modelConfigs.length || !config.model) {
|
||||
return;
|
||||
}
|
||||
return modelConfigs.find((m) => m.key === config.model?.key)?.description;
|
||||
}, [modelConfigs, config.model]);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(v: SpandrelImageToImageModelConfig | null) => {
|
||||
onChange({ ...config, model: v });
|
||||
},
|
||||
[config, onChange]
|
||||
);
|
||||
|
||||
const {
|
||||
options,
|
||||
value,
|
||||
onChange: onChangeModel,
|
||||
placeholder,
|
||||
noOptionsMessage,
|
||||
} = useModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: config.model,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const onScaleChanged = useCallback(
|
||||
(v: number) => {
|
||||
onChange({ ...config, scale: v });
|
||||
},
|
||||
[onChange, config]
|
||||
);
|
||||
const onAutoscaleChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChange({ ...config, autoScale: e.target.checked });
|
||||
},
|
||||
[onChange, config]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!config.model) {
|
||||
onChangeModel(options[0] ?? null);
|
||||
}
|
||||
}, [config.model, onChangeModel, options]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<FormControl w="full" orientation="vertical">
|
||||
<Flex w="full" alignItems="center">
|
||||
<FormLabel m={0} flexGrow={1}>
|
||||
{t('controlLayers.filter.spandrel.paramAutoScale')}
|
||||
</FormLabel>
|
||||
<Switch size="sm" isChecked={config.autoScale} onChange={onAutoscaleChanged} />
|
||||
</Flex>
|
||||
<FormHelperText>{t('controlLayers.filter.spandrel.paramAutoScaleDesc')}</FormHelperText>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={!config.autoScale}>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.spandrel.paramScale')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={config.scale}
|
||||
onChange={onScaleChanged}
|
||||
defaultValue={DEFAULTS.scale}
|
||||
min={1}
|
||||
max={16}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={config.scale}
|
||||
onChange={onScaleChanged}
|
||||
defaultValue={DEFAULTS.scale}
|
||||
min={1}
|
||||
max={16}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel m={0}>{t('controlLayers.filter.spandrel.paramModel')}</FormLabel>
|
||||
<Tooltip label={tooltipLabel}>
|
||||
<Box w="full">
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChangeModel}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
isDisabled={options.length === 0}
|
||||
/>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
</FormControl>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
FilterSpandrel.displayName = 'FilterSpandrel';
|
||||
@@ -1,24 +0,0 @@
|
||||
import { Grid } from '@invoke-ai/ui-library';
|
||||
import { CanvasHUDItemBbox } from 'features/controlLayers/components/HUD/CanvasHUDItemBbox';
|
||||
import { CanvasHUDItemScaledBbox } from 'features/controlLayers/components/HUD/CanvasHUDItemScaledBbox';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const CanvasHUD = memo(() => {
|
||||
return (
|
||||
<Grid
|
||||
bg="base.900"
|
||||
borderBottomEndRadius="base"
|
||||
p={2}
|
||||
gap={1}
|
||||
borderRadius="base"
|
||||
templateColumns="1fr 1fr"
|
||||
opacity={0.6}
|
||||
minW={64}
|
||||
>
|
||||
<CanvasHUDItemBbox />
|
||||
<CanvasHUDItemScaledBbox />
|
||||
</Grid>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasHUD.displayName = 'CanvasHUD';
|
||||
@@ -1,26 +0,0 @@
|
||||
import { GridItem, Text } from '@invoke-ai/ui-library';
|
||||
import type { Property } from 'csstype';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
label: string;
|
||||
value: string | number;
|
||||
color?: Property.Color;
|
||||
};
|
||||
|
||||
export const CanvasHUDItem = memo(({ label, value, color }: Props) => {
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<Text textAlign="end">{label}: </Text>
|
||||
</GridItem>
|
||||
<GridItem fontWeight="semibold">
|
||||
<Text textAlign="end" color={color}>
|
||||
{value}
|
||||
</Text>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasHUDItem.displayName = 'CanvasHUDItem';
|
||||
@@ -1,17 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasHUDItem } from 'features/controlLayers/components/HUD/CanvasHUDItem';
|
||||
import { selectBbox } from 'features/controlLayers/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectBboxRect = createSelector(selectBbox, (bbox) => bbox.rect);
|
||||
|
||||
export const CanvasHUDItemBbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const rect = useAppSelector(selectBboxRect);
|
||||
|
||||
return <CanvasHUDItem label={t('controlLayers.HUD.bbox')} value={`${rect.width}×${rect.height} px`} />;
|
||||
});
|
||||
|
||||
CanvasHUDItemBbox.displayName = 'CanvasHUDItemBbox';
|
||||
@@ -1,19 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasHUDItem } from 'features/controlLayers/components/HUD/CanvasHUDItem';
|
||||
import { selectBbox } from 'features/controlLayers/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectScaledSize = createSelector(selectBbox, (bbox) => bbox.scaledSize);
|
||||
|
||||
export const CanvasHUDItemScaledBbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const scaledSize = useAppSelector(selectScaledSize);
|
||||
|
||||
return (
|
||||
<CanvasHUDItem label={t('controlLayers.HUD.scaledBbox')} value={`${scaledSize.width}×${scaledSize.height} px`} />
|
||||
);
|
||||
});
|
||||
|
||||
CanvasHUDItemScaledBbox.displayName = 'CanvasHUDItemScaledBbox';
|
||||
@@ -1,133 +0,0 @@
|
||||
import { Box, Flex, Icon, Text } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { Property } from 'csstype';
|
||||
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { useEntityTitle } from 'features/controlLayers/hooks/useEntityTitle';
|
||||
import { useEntityTypeIsHidden } from 'features/controlLayers/hooks/useEntityTypeIsHidden';
|
||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import {
|
||||
selectCanvasSlice,
|
||||
selectEntityOrThrow,
|
||||
selectSelectedEntityIdentifier,
|
||||
} from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiWarningCircleFill } from 'react-icons/pi';
|
||||
|
||||
type ContentProps = {
|
||||
entityIdentifier: CanvasEntityIdentifier;
|
||||
adapter: CanvasEntityAdapter;
|
||||
};
|
||||
|
||||
const $isFilteringFallback = atom(false);
|
||||
|
||||
type EntityStatus = {
|
||||
value: string;
|
||||
color?: Property.Color;
|
||||
};
|
||||
|
||||
const CanvasSelectedEntityStatusAlertContent = memo(({ entityIdentifier, adapter }: ContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
const title = useEntityTitle(entityIdentifier);
|
||||
const selectIsEnabled = useMemo(
|
||||
() => createSelector(selectCanvasSlice, (canvas) => selectEntityOrThrow(canvas, entityIdentifier).isEnabled),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const selectIsLocked = useMemo(
|
||||
() => createSelector(selectCanvasSlice, (canvas) => selectEntityOrThrow(canvas, entityIdentifier).isLocked),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const isEnabled = useAppSelector(selectIsEnabled);
|
||||
const isLocked = useAppSelector(selectIsLocked);
|
||||
const isHidden = useEntityTypeIsHidden(entityIdentifier.type);
|
||||
const isFiltering = useStore(adapter.filterer?.$isFiltering ?? $isFilteringFallback);
|
||||
const isTransforming = useStore(adapter.transformer.$isTransforming);
|
||||
|
||||
const status = useMemo<EntityStatus | null>(() => {
|
||||
if (isFiltering) {
|
||||
return {
|
||||
value: t('controlLayers.HUD.entityStatus.isFiltering'),
|
||||
color: 'invokeBlue.300',
|
||||
};
|
||||
}
|
||||
|
||||
if (isTransforming) {
|
||||
return {
|
||||
value: t('controlLayers.HUD.entityStatus.isTransforming'),
|
||||
color: 'invokeBlue.300',
|
||||
};
|
||||
}
|
||||
|
||||
if (isHidden) {
|
||||
return {
|
||||
value: t('controlLayers.HUD.entityStatus.isHidden'),
|
||||
color: 'invokePurple.300',
|
||||
};
|
||||
}
|
||||
|
||||
if (isLocked) {
|
||||
return {
|
||||
value: t('controlLayers.HUD.entityStatus.isLocked'),
|
||||
color: 'invokeRed.300',
|
||||
};
|
||||
}
|
||||
|
||||
if (!isEnabled) {
|
||||
return {
|
||||
value: t('controlLayers.HUD.entityStatus.isDisabled'),
|
||||
color: 'invokeRed.300',
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}, [isFiltering, isTransforming, isHidden, isLocked, isEnabled, t]);
|
||||
|
||||
if (!status) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box position="relative" shadow="dark-lg">
|
||||
<Flex
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
left={0}
|
||||
bottom={0}
|
||||
bg={status.color}
|
||||
opacity={0.3}
|
||||
borderRadius="base"
|
||||
borderColor="whiteAlpha.400"
|
||||
borderWidth={1}
|
||||
/>
|
||||
<Flex px={6} py={4} gap={6} alignItems="center" justifyContent="center">
|
||||
<Icon as={PiWarningCircleFill} />
|
||||
<Text as="span" h={8}>
|
||||
<Text as="span" fontWeight="semibold">
|
||||
{title}
|
||||
</Text>{' '}
|
||||
{status.value}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSelectedEntityStatusAlertContent.displayName = 'CanvasSelectedEntityStatusAlertContent';
|
||||
|
||||
export const CanvasSelectedEntityStatusAlert = memo(() => {
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const adapter = useEntityAdapterSafe(selectedEntityIdentifier);
|
||||
|
||||
if (!selectedEntityIdentifier || !adapter) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <CanvasSelectedEntityStatusAlertContent entityIdentifier={selectedEntityIdentifier} adapter={adapter} />;
|
||||
});
|
||||
|
||||
CanvasSelectedEntityStatusAlert.displayName = 'CanvasSelectedEntityStatusAlert';
|
||||
@@ -0,0 +1,43 @@
|
||||
import { Grid, GridItem, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectBbox = createSelector(selectCanvasSlice, (canvas) => canvas.bbox);
|
||||
|
||||
export const HeadsUpDisplay = memo(() => {
|
||||
const bbox = useAppSelector(selectBbox);
|
||||
|
||||
return (
|
||||
<Grid
|
||||
bg="base.900"
|
||||
borderBottomEndRadius="base"
|
||||
p={2}
|
||||
gap={2}
|
||||
borderRadius="base"
|
||||
templateColumns="auto auto"
|
||||
opacity={0.6}
|
||||
>
|
||||
<HUDItem label="BBox" value={`${bbox.rect.width}×${bbox.rect.height} px`} />
|
||||
<HUDItem label="Scaled BBox" value={`${bbox.scaledSize.width}×${bbox.scaledSize.height} px`} />
|
||||
</Grid>
|
||||
);
|
||||
});
|
||||
|
||||
HeadsUpDisplay.displayName = 'HeadsUpDisplay';
|
||||
|
||||
const HUDItem = memo(({ label, value }: { label: string; value: string | number }) => {
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<Text textAlign="end">{label}: </Text>
|
||||
</GridItem>
|
||||
<GridItem fontWeight="semibold">
|
||||
<Text>{value}</Text>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
HUDItem.displayName = 'HUDItem';
|
||||
@@ -85,7 +85,7 @@ export const IPAdapterImagePreview = memo(
|
||||
/>
|
||||
|
||||
{controlImage && (
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
|
||||
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
|
||||
@@ -1,28 +1,42 @@
|
||||
import { Button, Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import {
|
||||
buildSelectValidRegionalGuidanceActions,
|
||||
useAddRegionalGuidanceIPAdapter,
|
||||
useAddRegionalGuidanceNegativePrompt,
|
||||
useAddRegionalGuidancePositivePrompt,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useMemo } from 'react';
|
||||
rgIPAdapterAdded,
|
||||
rgNegativePromptChanged,
|
||||
rgPositivePromptChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
|
||||
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
|
||||
const { t } = useTranslation();
|
||||
const addRegionalGuidanceIPAdapter = useAddRegionalGuidanceIPAdapter(entityIdentifier);
|
||||
const addRegionalGuidancePositivePrompt = useAddRegionalGuidancePositivePrompt(entityIdentifier);
|
||||
const addRegionalGuidanceNegativePrompt = useAddRegionalGuidanceNegativePrompt(entityIdentifier);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const selectValidActions = useMemo(
|
||||
() => buildSelectValidRegionalGuidanceActions(entityIdentifier),
|
||||
() =>
|
||||
createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
const entity = selectEntityOrThrow(canvas, entityIdentifier);
|
||||
return {
|
||||
canAddPositivePrompt: entity?.positivePrompt === null,
|
||||
canAddNegativePrompt: entity?.negativePrompt === null,
|
||||
};
|
||||
}),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const validActions = useAppSelector(selectValidActions);
|
||||
const addPositivePrompt = useCallback(() => {
|
||||
dispatch(rgPositivePromptChanged({ entityIdentifier, prompt: '' }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
const addNegativePrompt = useCallback(() => {
|
||||
dispatch(rgNegativePromptChanged({ entityIdentifier, prompt: '' }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
dispatch(rgIPAdapterAdded({ entityIdentifier }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<Flex w="full" p={2} justifyContent="space-between">
|
||||
@@ -30,7 +44,7 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addRegionalGuidancePositivePrompt}
|
||||
onClick={addPositivePrompt}
|
||||
isDisabled={!validActions.canAddPositivePrompt}
|
||||
>
|
||||
{t('common.positivePrompt')}
|
||||
@@ -39,12 +53,12 @@ export const RegionalGuidanceAddPromptsIPAdapterButtons = () => {
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
leftIcon={<PiPlusBold />}
|
||||
onClick={addRegionalGuidanceNegativePrompt}
|
||||
onClick={addNegativePrompt}
|
||||
isDisabled={!validActions.canAddNegativePrompt}
|
||||
>
|
||||
{t('common.negativePrompt')}
|
||||
</Button>
|
||||
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addRegionalGuidanceIPAdapter}>
|
||||
<Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||
{t('common.ipAdapter')}
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
@@ -1,38 +1,53 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import {
|
||||
buildSelectValidRegionalGuidanceActions,
|
||||
useAddRegionalGuidanceIPAdapter,
|
||||
useAddRegionalGuidanceNegativePrompt,
|
||||
useAddRegionalGuidancePositivePrompt,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { memo, useMemo } from 'react';
|
||||
import {
|
||||
rgIPAdapterAdded,
|
||||
rgNegativePromptChanged,
|
||||
rgPositivePromptChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const RegionalGuidanceMenuItemsAddPromptsAndIPAdapter = memo(() => {
|
||||
const entityIdentifier = useEntityIdentifierContext('regional_guidance');
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const addRegionalGuidanceIPAdapter = useAddRegionalGuidanceIPAdapter(entityIdentifier);
|
||||
const addRegionalGuidancePositivePrompt = useAddRegionalGuidancePositivePrompt(entityIdentifier);
|
||||
const addRegionalGuidanceNegativePrompt = useAddRegionalGuidanceNegativePrompt(entityIdentifier);
|
||||
const selectValidActions = useMemo(
|
||||
() => buildSelectValidRegionalGuidanceActions(entityIdentifier),
|
||||
() =>
|
||||
createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
const entity = selectEntity(canvas, entityIdentifier);
|
||||
return {
|
||||
canAddPositivePrompt: entity?.positivePrompt === null,
|
||||
canAddNegativePrompt: entity?.negativePrompt === null,
|
||||
};
|
||||
}),
|
||||
[entityIdentifier]
|
||||
);
|
||||
const validActions = useAppSelector(selectValidActions);
|
||||
const addPositivePrompt = useCallback(() => {
|
||||
dispatch(rgPositivePromptChanged({ entityIdentifier, prompt: '' }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
const addNegativePrompt = useCallback(() => {
|
||||
dispatch(rgNegativePromptChanged({ entityIdentifier, prompt: '' }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
dispatch(rgIPAdapterAdded({ entityIdentifier }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem onClick={addRegionalGuidancePositivePrompt} isDisabled={!validActions.canAddPositivePrompt || isBusy}>
|
||||
<MenuItem onClick={addPositivePrompt} isDisabled={!validActions.canAddPositivePrompt || isBusy}>
|
||||
{t('controlLayers.addPositivePrompt')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={addRegionalGuidanceNegativePrompt} isDisabled={!validActions.canAddNegativePrompt || isBusy}>
|
||||
<MenuItem onClick={addNegativePrompt} isDisabled={!validActions.canAddNegativePrompt || isBusy}>
|
||||
{t('controlLayers.addNegativePrompt')}
|
||||
</MenuItem>
|
||||
<MenuItem onClick={addRegionalGuidanceIPAdapter} isDisabled={isBusy}>
|
||||
<MenuItem onClick={addIPAdapter} isDisabled={isBusy}>
|
||||
{t('controlLayers.addIPAdapter')}
|
||||
</MenuItem>
|
||||
</>
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAutoSave, settingsAutoSaveToggled } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectCanvasSettingsSlice, settingsAutoSaveToggled } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selectAutoSave = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.autoSave);
|
||||
|
||||
export const CanvasSettingsAutoSaveCheckbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const autoSave = useAppSelector(selectAutoSave);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsAutoSaveToggled());
|
||||
}, [dispatch]);
|
||||
const onChange = useCallback(() => dispatch(settingsAutoSaveToggled()), [dispatch]);
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.autoSave')}</FormLabel>
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSnapToGrid, settingsSnapToGridToggled } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasSettingsSnapToGridCheckbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const snapToGrid = useAppSelector(selectSnapToGrid);
|
||||
const onChange = useCallback<ChangeEventHandler<HTMLInputElement>>(() => {
|
||||
dispatch(settingsSnapToGridToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.settings.snapToGrid.label')}</FormLabel>
|
||||
<Checkbox isChecked={snapToGrid} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsSnapToGridCheckbox.displayName = 'CanvasSettingsSnapToGrid';
|
||||
@@ -15,7 +15,6 @@ import { CanvasSettingsClearHistoryButton } from 'features/controlLayers/compone
|
||||
import { CanvasSettingsClipToBboxCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsClipToBboxCheckbox';
|
||||
import { CanvasSettingsCompositeMaskedRegionsCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsCompositeMaskedRegionsCheckbox';
|
||||
import { CanvasSettingsDynamicGridSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsDynamicGridSwitch';
|
||||
import { CanvasSettingsSnapToGridCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsGridSize';
|
||||
import { CanvasSettingsInvertScrollCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsInvertScrollCheckbox';
|
||||
import { CanvasSettingsLogDebugInfoButton } from 'features/controlLayers/components/Settings/CanvasSettingsLogDebugInfo';
|
||||
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
|
||||
@@ -40,7 +39,6 @@ export const CanvasSettingsPopover = memo(() => {
|
||||
<CanvasSettingsInvertScrollCheckbox />
|
||||
<CanvasSettingsClipToBboxCheckbox />
|
||||
<CanvasSettingsCompositeMaskedRegionsCheckbox />
|
||||
<CanvasSettingsSnapToGridCheckbox />
|
||||
<CanvasSettingsDynamicGridSwitch />
|
||||
<CanvasSettingsShowHUDSwitch />
|
||||
<CanvasSettingsResetButton />
|
||||
|
||||
@@ -4,9 +4,7 @@ import { $socket } from 'app/hooks/useSocketIO';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasHUD } from 'features/controlLayers/components/HUD/CanvasHUD';
|
||||
import { CanvasSelectedEntityStatusAlert } from 'features/controlLayers/components/HUD/CanvasSelectedEntityStatusAlert';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { HeadsUpDisplay } from 'features/controlLayers/components/HeadsUpDisplay';
|
||||
import { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { TRANSPARENCY_CHECKERBOARD_PATTERN_DATAURL } from 'features/controlLayers/konva/patterns/transparency-checkerboard-pattern';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
@@ -99,16 +97,11 @@ export const StageComponent = memo(() => {
|
||||
overflow="hidden"
|
||||
data-testid="control-layers-canvas"
|
||||
/>
|
||||
<CanvasManagerProviderGate>
|
||||
{showHUD && (
|
||||
<Flex position="absolute" top={1} insetInlineStart={1} pointerEvents="none">
|
||||
<CanvasHUD />
|
||||
</Flex>
|
||||
)}
|
||||
<Flex position="absolute" top={1} insetInlineEnd={1} pointerEvents="none">
|
||||
<CanvasSelectedEntityStatusAlert />
|
||||
{showHUD && (
|
||||
<Flex position="absolute" top={1} insetInlineStart={1} pointerEvents="none">
|
||||
<HeadsUpDisplay />
|
||||
</Flex>
|
||||
</CanvasManagerProviderGate>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { INTERACTION_SCOPES, useScopeOnMount } from 'common/hooks/interactionScopes';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import {
|
||||
selectCanvasStagingAreaSlice,
|
||||
stagingAreaImageAccepted,
|
||||
stagingAreaNextStagedImageSelected,
|
||||
stagingAreaPrevStagedImageSelected,
|
||||
stagingAreaReset,
|
||||
stagingAreaStagedImageDiscarded,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
selectCanvasSessionSlice,
|
||||
sessionNextStagedImageSelected,
|
||||
sessionPrevStagedImageSelected,
|
||||
sessionStagedImageDiscarded,
|
||||
sessionStagingAreaImageAccepted,
|
||||
sessionStagingAreaReset,
|
||||
} from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -28,16 +28,16 @@ import {
|
||||
import { useChangeImageIsIntermediateMutation } from 'services/api/endpoints/images';
|
||||
|
||||
const selectStagedImageIndex = createSelector(
|
||||
selectCanvasStagingAreaSlice,
|
||||
(stagingArea) => stagingArea.selectedStagedImageIndex
|
||||
selectCanvasSessionSlice,
|
||||
(canvasSession) => canvasSession.selectedStagedImageIndex
|
||||
);
|
||||
|
||||
const selectSelectedImage = createSelector(
|
||||
[selectCanvasStagingAreaSlice, selectStagedImageIndex],
|
||||
(stagingArea, index) => stagingArea.stagedImages[index] ?? null
|
||||
[selectCanvasSessionSlice, selectStagedImageIndex],
|
||||
(canvasSession, index) => canvasSession.stagedImages[index] ?? null
|
||||
);
|
||||
|
||||
const selectImageCount = createSelector(selectCanvasStagingAreaSlice, (stagingArea) => stagingArea.stagedImages.length);
|
||||
const selectImageCount = createSelector(selectCanvasSessionSlice, (canvasSession) => canvasSession.stagedImages.length);
|
||||
|
||||
export const StagingAreaToolbar = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -53,18 +53,18 @@ export const StagingAreaToolbar = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onPrev = useCallback(() => {
|
||||
dispatch(stagingAreaPrevStagedImageSelected());
|
||||
dispatch(sessionPrevStagedImageSelected());
|
||||
}, [dispatch]);
|
||||
|
||||
const onNext = useCallback(() => {
|
||||
dispatch(stagingAreaNextStagedImageSelected());
|
||||
dispatch(sessionNextStagedImageSelected());
|
||||
}, [dispatch]);
|
||||
|
||||
const onAccept = useCallback(() => {
|
||||
if (!selectedImage) {
|
||||
return;
|
||||
}
|
||||
dispatch(stagingAreaImageAccepted({ index }));
|
||||
dispatch(sessionStagingAreaImageAccepted({ index }));
|
||||
}, [dispatch, index, selectedImage]);
|
||||
|
||||
const onDiscardOne = useCallback(() => {
|
||||
@@ -72,14 +72,14 @@ export const StagingAreaToolbar = memo(() => {
|
||||
return;
|
||||
}
|
||||
if (imageCount === 1) {
|
||||
dispatch(stagingAreaReset());
|
||||
dispatch(sessionStagingAreaReset());
|
||||
} else {
|
||||
dispatch(stagingAreaStagedImageDiscarded({ index }));
|
||||
dispatch(sessionStagedImageDiscarded({ index }));
|
||||
}
|
||||
}, [selectedImage, imageCount, dispatch, index]);
|
||||
|
||||
const onDiscardAll = useCallback(() => {
|
||||
dispatch(stagingAreaReset());
|
||||
dispatch(sessionStagingAreaReset());
|
||||
}, [dispatch]);
|
||||
|
||||
const onToggleShouldShowStagedImage = useCallback(() => {
|
||||
|
||||
@@ -4,7 +4,6 @@ import { CanvasSettingsPopover } from 'features/controlLayers/components/Setting
|
||||
import { ToolChooser } from 'features/controlLayers/components/Tool/ToolChooser';
|
||||
import { ToolColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker';
|
||||
import { ToolSettings } from 'features/controlLayers/components/Tool/ToolSettings';
|
||||
import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton';
|
||||
import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton';
|
||||
import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton';
|
||||
import { CanvasToolbarScale } from 'features/controlLayers/components/Toolbar/CanvasToolbarScale';
|
||||
@@ -14,6 +13,8 @@ import { useCanvasEntityQuickSwitchHotkey } from 'features/controlLayers/hooks/u
|
||||
import { useCanvasResetLayerHotkey } from 'features/controlLayers/hooks/useCanvasResetLayerHotkey';
|
||||
import { useCanvasUndoRedoHotkeys } from 'features/controlLayers/hooks/useCanvasUndoRedoHotkeys';
|
||||
import { useNextPrevEntityHotkeys } from 'features/controlLayers/hooks/useNextPrevEntity';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { ViewerToggle } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const CanvasToolbar = memo(() => {
|
||||
@@ -26,6 +27,7 @@ export const CanvasToolbar = memo(() => {
|
||||
return (
|
||||
<CanvasManagerProviderGate>
|
||||
<Flex w="full" gap={2} alignItems="center">
|
||||
<ToggleProgressButton />
|
||||
<ToolChooser />
|
||||
<Spacer />
|
||||
<ToolSettings />
|
||||
@@ -34,9 +36,9 @@ export const CanvasToolbar = memo(() => {
|
||||
<CanvasToolbarResetViewButton />
|
||||
<Spacer />
|
||||
<ToolColorPicker />
|
||||
<CanvasToolbarFitBboxToLayersButton />
|
||||
<CanvasToolbarSaveToGalleryButton />
|
||||
<CanvasSettingsPopover />
|
||||
<ViewerToggle />
|
||||
</Flex>
|
||||
</CanvasManagerProviderGate>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { $alt, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { INTERACTION_SCOPES } from 'common/hooks/interactionScopes';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Button, ButtonGroup, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
|
||||
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntityAdapter/types';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsCounterClockwiseBold, PiArrowsOutBold, PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import {
|
||||
useAddControlLayer,
|
||||
useAddInpaintMask,
|
||||
useAddIPAdapter,
|
||||
useAddRasterLayer,
|
||||
useAddRegionalGuidance,
|
||||
} from 'features/controlLayers/hooks/addLayerHooks';
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -17,31 +18,26 @@ type Props = {
|
||||
|
||||
export const CanvasEntityAddOfTypeButton = memo(({ type }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const addInpaintMask = useAddInpaintMask();
|
||||
const addRegionalGuidance = useAddRegionalGuidance();
|
||||
const addRasterLayer = useAddRasterLayer();
|
||||
const addControlLayer = useAddControlLayer();
|
||||
const addIPAdapter = useAddIPAdapter();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const onClick = useCallback(() => {
|
||||
switch (type) {
|
||||
case 'inpaint_mask':
|
||||
addInpaintMask();
|
||||
dispatch(inpaintMaskAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'regional_guidance':
|
||||
addRegionalGuidance();
|
||||
dispatch(rgAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'raster_layer':
|
||||
addRasterLayer();
|
||||
dispatch(rasterLayerAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'control_layer':
|
||||
addControlLayer();
|
||||
dispatch(controlLayerAdded({ isSelected: true }));
|
||||
break;
|
||||
case 'ip_adapter':
|
||||
addIPAdapter();
|
||||
dispatch(ipaAdded({ isSelected: true }));
|
||||
break;
|
||||
}
|
||||
}, [addControlLayer, addIPAdapter, addInpaintMask, addRasterLayer, addRegionalGuidance, type]);
|
||||
}, [dispatch, type]);
|
||||
|
||||
const label = useMemo(() => {
|
||||
switch (type) {
|
||||
|
||||
@@ -27,7 +27,7 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex w="full">
|
||||
<Flex w="full" px={1}>
|
||||
<Flex
|
||||
flexGrow={1}
|
||||
as={Button}
|
||||
|
||||
@@ -1,17 +1,33 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityFilter } from 'features/controlLayers/hooks/useEntityFilter';
|
||||
import { memo } from 'react';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiShootingStarBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasEntityMenuItemsFilter = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const canvasManager = useCanvasManager();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const filter = useEntityFilter(entityIdentifier);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
if (!entityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isFilterableEntityIdentifier(entityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
canvasManager.filter.startFilter(entityIdentifier);
|
||||
}, [canvasManager.filter, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem onClick={filter.start} icon={<PiShootingStarBold />} isDisabled={filter.isDisabled}>
|
||||
<MenuItem onClick={onClick} icon={<PiShootingStarBold />} isDisabled={isBusy || isStaging}>
|
||||
{t('controlLayers.filter.filter')}
|
||||
</MenuItem>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,30 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import { useEntityTransform } from 'features/controlLayers/hooks/useEntityTransform';
|
||||
import { memo } from 'react';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useEntityAdapter } from 'features/controlLayers/hooks/useEntityAdapter';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiFrameCornersBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasEntityMenuItemsTransform = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const entityIdentifier = useEntityIdentifierContext();
|
||||
const transform = useEntityTransform(entityIdentifier);
|
||||
const adapter = useEntityAdapter(entityIdentifier);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
if (!isTransformableEntityIdentifier(entityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
adapter.transformer.startTransform();
|
||||
}, [adapter.transformer, entityIdentifier]);
|
||||
|
||||
return (
|
||||
<MenuItem onClick={transform.start} icon={<PiFrameCornersBold />} isDisabled={transform.isDisabled}>
|
||||
<MenuItem onClick={onClick} icon={<PiFrameCornersBold />} isDisabled={isBusy || isStaging}>
|
||||
{t('controlLayers.transform.transform')}
|
||||
</MenuItem>
|
||||
);
|
||||
|
||||
@@ -5,8 +5,8 @@ import { isOk, withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { useEntityTypeCount } from 'features/controlLayers/hooks/useEntityTypeCount';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { inpaintMaskAdded, rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -43,7 +43,7 @@ export const CanvasEntityMergeVisibleButton = memo(({ type }: Props) => {
|
||||
objects: [imageDTOToImageObject(result.value)],
|
||||
position: { x: Math.floor(rect.x), y: Math.floor(rect.y) },
|
||||
},
|
||||
isMergingVisible: true,
|
||||
deleteOthers: true,
|
||||
})
|
||||
);
|
||||
toast({ title: t('controlLayers.mergeVisibleOk') });
|
||||
@@ -65,7 +65,7 @@ export const CanvasEntityMergeVisibleButton = memo(({ type }: Props) => {
|
||||
objects: [imageDTOToImageObject(result.value)],
|
||||
position: { x: Math.floor(rect.x), y: Math.floor(rect.y) },
|
||||
},
|
||||
isMergingVisible: true,
|
||||
deleteOthers: true,
|
||||
})
|
||||
);
|
||||
toast({ title: t('controlLayers.mergeVisibleOk') });
|
||||
|
||||
@@ -37,20 +37,13 @@ export const CanvasEntityPreviewImage = memo(() => {
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||
const cache = useStore(adapter.renderer.$canvasCache);
|
||||
useEffect(() => {
|
||||
if (!canvasRef.current) {
|
||||
if (!cache || !canvasRef.current) {
|
||||
return;
|
||||
}
|
||||
const ctx = canvasRef.current.getContext('2d');
|
||||
if (!ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!cache) {
|
||||
// Draw an empty canvas
|
||||
ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
|
||||
return;
|
||||
}
|
||||
|
||||
const { rect, canvas } = cache;
|
||||
|
||||
ctx.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
|
||||
@@ -88,7 +81,6 @@ export const CanvasEntityPreviewImage = memo(() => {
|
||||
borderRadius="sm"
|
||||
borderWidth={1}
|
||||
bg="base.900"
|
||||
flexShrink={0}
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import { $canvasManager, type CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useContext } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
@@ -106,51 +105,3 @@ export const useEntityAdapter = ():
|
||||
assert(adapter, 'useEntityAdapter must be used within a CanvasRasterLayerAdapterGate');
|
||||
return adapter;
|
||||
};
|
||||
|
||||
export const useEntityAdapterSafe = (
|
||||
entityIdentifier: CanvasEntityIdentifier | null
|
||||
):
|
||||
| CanvasEntityAdapterRasterLayer
|
||||
| CanvasEntityAdapterControlLayer
|
||||
| CanvasEntityAdapterInpaintMask
|
||||
| CanvasEntityAdapterRegionalGuidance
|
||||
| null => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const regionalGuidanceAdapters = useSyncExternalStore(
|
||||
canvasManager.adapters.regionMasks.subscribe,
|
||||
canvasManager.adapters.regionMasks.getSnapshot
|
||||
);
|
||||
const rasterLayerAdapters = useSyncExternalStore(
|
||||
canvasManager.adapters.rasterLayers.subscribe,
|
||||
canvasManager.adapters.rasterLayers.getSnapshot
|
||||
);
|
||||
const controlLayerAdapters = useSyncExternalStore(
|
||||
canvasManager.adapters.controlLayers.subscribe,
|
||||
canvasManager.adapters.controlLayers.getSnapshot
|
||||
);
|
||||
const inpaintMaskAdapters = useSyncExternalStore(
|
||||
canvasManager.adapters.inpaintMasks.subscribe,
|
||||
canvasManager.adapters.inpaintMasks.getSnapshot
|
||||
);
|
||||
|
||||
const adapter = useMemo(() => {
|
||||
if (!entityIdentifier) {
|
||||
return null;
|
||||
}
|
||||
if (entityIdentifier.type === 'raster_layer') {
|
||||
return rasterLayerAdapters.get(entityIdentifier.id) ?? null;
|
||||
}
|
||||
if (entityIdentifier.type === 'control_layer') {
|
||||
return controlLayerAdapters.get(entityIdentifier.id) ?? null;
|
||||
}
|
||||
if (entityIdentifier.type === 'inpaint_mask') {
|
||||
return inpaintMaskAdapters.get(entityIdentifier.id) ?? null;
|
||||
}
|
||||
if (entityIdentifier.type === 'regional_guidance') {
|
||||
return regionalGuidanceAdapters.get(entityIdentifier.id) ?? null;
|
||||
}
|
||||
return null;
|
||||
}, [controlLayerAdapters, entityIdentifier, inpaintMaskAdapters, rasterLayerAdapters, regionalGuidanceAdapters]);
|
||||
|
||||
return adapter;
|
||||
};
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
rgIPAdapterAdded,
|
||||
rgNegativePromptChanged,
|
||||
rgPositivePromptChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
|
||||
import type {
|
||||
CanvasEntityIdentifier,
|
||||
ControlNetConfig,
|
||||
IPAdapterConfig,
|
||||
T2IAdapterConfig,
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { useCallback } from 'react';
|
||||
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const selectDefaultControlAdapter = createSelector(
|
||||
selectModelConfigsQuery,
|
||||
selectBase,
|
||||
(query, base): ControlNetConfig | T2IAdapterConfig => {
|
||||
const { data } = query;
|
||||
let model: ControlNetModelConfig | T2IAdapterModelConfig | null = null;
|
||||
if (data) {
|
||||
const modelConfigs = modelConfigsAdapterSelectors
|
||||
.selectAll(data)
|
||||
.filter(isControlNetOrT2IAdapterModelConfig)
|
||||
.sort((a) => (a.type === 'controlnet' ? -1 : 1)); // Prefer ControlNet models
|
||||
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
|
||||
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
|
||||
}
|
||||
const controlAdapter = model?.type === 't2i_adapter' ? deepClone(initialT2IAdapter) : deepClone(initialControlNet);
|
||||
if (model) {
|
||||
controlAdapter.model = zModelIdentifierField.parse(model);
|
||||
}
|
||||
return controlAdapter;
|
||||
}
|
||||
);
|
||||
|
||||
const selectDefaultIPAdapter = createSelector(selectModelConfigsQuery, selectBase, (query, base): IPAdapterConfig => {
|
||||
const { data } = query;
|
||||
let model: IPAdapterModelConfig | null = null;
|
||||
if (data) {
|
||||
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
|
||||
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
|
||||
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
|
||||
}
|
||||
const ipAdapter = deepClone(initialIPAdapter);
|
||||
if (model) {
|
||||
ipAdapter.model = zModelIdentifierField.parse(model);
|
||||
}
|
||||
return ipAdapter;
|
||||
});
|
||||
|
||||
export const useAddControlLayer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
|
||||
const addControlLayer = useCallback(() => {
|
||||
const overrides = { controlAdapter: defaultControlAdapter };
|
||||
dispatch(controlLayerAdded({ isSelected: true, overrides }));
|
||||
}, [defaultControlAdapter, dispatch]);
|
||||
|
||||
return addControlLayer;
|
||||
};
|
||||
|
||||
export const useAddRasterLayer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const addRasterLayer = useCallback(() => {
|
||||
dispatch(rasterLayerAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
|
||||
return addRasterLayer;
|
||||
};
|
||||
|
||||
export const useAddInpaintMask = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const addInpaintMask = useCallback(() => {
|
||||
dispatch(inpaintMaskAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
|
||||
return addInpaintMask;
|
||||
};
|
||||
|
||||
export const useAddRegionalGuidance = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const addRegionalGuidance = useCallback(() => {
|
||||
dispatch(rgAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
|
||||
return addRegionalGuidance;
|
||||
};
|
||||
|
||||
export const useAddIPAdapter = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const addControlLayer = useCallback(() => {
|
||||
const overrides = { ipAdapter: defaultIPAdapter };
|
||||
dispatch(ipaAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
|
||||
return addControlLayer;
|
||||
};
|
||||
|
||||
export const useAddRegionalGuidanceIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
|
||||
const addRegionalGuidanceIPAdapter = useCallback(() => {
|
||||
dispatch(rgIPAdapterAdded({ entityIdentifier, overrides: defaultIPAdapter }));
|
||||
}, [defaultIPAdapter, dispatch, entityIdentifier]);
|
||||
|
||||
return addRegionalGuidanceIPAdapter;
|
||||
};
|
||||
|
||||
export const useAddRegionalGuidancePositivePrompt = (entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const addRegionalGuidancePositivePrompt = useCallback(() => {
|
||||
dispatch(rgPositivePromptChanged({ entityIdentifier, prompt: '' }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return addRegionalGuidancePositivePrompt;
|
||||
};
|
||||
|
||||
export const useAddRegionalGuidanceNegativePrompt = (entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const addRegionalGuidanceNegativePrompt = useCallback(() => {
|
||||
dispatch(rgNegativePromptChanged({ entityIdentifier, prompt: '' }));
|
||||
}, [dispatch, entityIdentifier]);
|
||||
|
||||
return addRegionalGuidanceNegativePrompt;
|
||||
};
|
||||
|
||||
export const buildSelectValidRegionalGuidanceActions = (
|
||||
entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>
|
||||
) => {
|
||||
return createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
||||
const entity = selectEntityOrThrow(canvas, entityIdentifier);
|
||||
return {
|
||||
canAddPositivePrompt: entity?.positivePrompt === null,
|
||||
canAddNegativePrompt: entity?.negativePrompt === null,
|
||||
};
|
||||
});
|
||||
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { entityDeleted } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
@@ -3,7 +3,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { entityReset } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { isMaskEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
@@ -25,8 +24,8 @@ export function useCanvasResetLayerHotkey() {
|
||||
}, [dispatch, selectedEntityIdentifier]);
|
||||
|
||||
const isResetEnabled = useMemo(
|
||||
() => selectedEntityIdentifier !== null && isMaskEntityIdentifier(selectedEntityIdentifier),
|
||||
[selectedEntityIdentifier]
|
||||
() => selectedEntityIdentifier?.type === 'inpaint_mask',
|
||||
[selectedEntityIdentifier?.type]
|
||||
);
|
||||
|
||||
useHotkeys('shift+c', resetSelectedLayer, { enabled: isResetEnabled }, [isResetEnabled, resetSelectedLayer]);
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterControlLayer';
|
||||
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterInpaintMask';
|
||||
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRasterLayer';
|
||||
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntityAdapter/CanvasEntityAdapterRegionalGuidance';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const useEntityAdapter = (
|
||||
entityIdentifier: CanvasEntityIdentifier
|
||||
):
|
||||
| CanvasEntityAdapterRasterLayer
|
||||
| CanvasEntityAdapterControlLayer
|
||||
| CanvasEntityAdapterInpaintMask
|
||||
| CanvasEntityAdapterRegionalGuidance => {
|
||||
const canvasManager = useCanvasManager();
|
||||
|
||||
const adapter = useMemo(() => {
|
||||
const adapter = canvasManager.getAdapter(entityIdentifier);
|
||||
assert(adapter, 'Entity adapter not found');
|
||||
return adapter;
|
||||
}, [canvasManager, entityIdentifier]);
|
||||
|
||||
return adapter;
|
||||
};
|
||||
@@ -1,75 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import type { AnyObjectRenderer } from 'features/controlLayers/konva/CanvasObject/types';
|
||||
import { getEmptyRect } from 'features/controlLayers/konva/util';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
|
||||
import { isFilterableEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useMemo, useSyncExternalStore } from 'react';
|
||||
|
||||
// When the entity is empty (the rect has no size) or there are no renderers, we have nothing to filter. Because the
|
||||
// entity is dynamic, and we need reactivity on these values, we need to do a little hack. These fallback objects
|
||||
// can be used to provide a default value for the useStore and useSyncExternalStore hooks, which require _some_ value
|
||||
// to be used.
|
||||
const $fallbackPixelRect = atom<Rect>(getEmptyRect());
|
||||
const fallbackRenderersMap = new SyncableMap<string, AnyObjectRenderer>();
|
||||
|
||||
export const useEntityFilter = (entityIdentifier: CanvasEntityIdentifier | null) => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const adapter = useEntityAdapterSafe(entityIdentifier);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
// Use the fallback pixel rect if the adapter is not available
|
||||
const pixelRect = useStore(adapter?.transformer.$pixelRect ?? $fallbackPixelRect);
|
||||
// Use the fallback renderers map if the adapter is not available
|
||||
const renderers = useSyncExternalStore(
|
||||
adapter?.renderer.renderers.subscribe ?? fallbackRenderersMap.subscribe,
|
||||
adapter?.renderer.renderers.getSnapshot ?? fallbackRenderersMap.getSnapshot
|
||||
);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
if (!entityIdentifier) {
|
||||
return true;
|
||||
}
|
||||
if (!isFilterableEntityIdentifier(entityIdentifier)) {
|
||||
return true;
|
||||
}
|
||||
if (!adapter) {
|
||||
return true;
|
||||
}
|
||||
if (isBusy || isStaging) {
|
||||
return true;
|
||||
}
|
||||
if (pixelRect.width === 0 || pixelRect.height === 0) {
|
||||
return true;
|
||||
}
|
||||
if (renderers.size === 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}, [entityIdentifier, adapter, isBusy, isStaging, pixelRect.width, pixelRect.height, renderers.size]);
|
||||
|
||||
const start = useCallback(() => {
|
||||
if (isDisabled) {
|
||||
return;
|
||||
}
|
||||
if (!entityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isFilterableEntityIdentifier(entityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
const adapter = canvasManager.getAdapter(entityIdentifier);
|
||||
if (!adapter) {
|
||||
return;
|
||||
}
|
||||
adapter.filterer.start();
|
||||
}, [isDisabled, entityIdentifier, canvasManager]);
|
||||
|
||||
return { isDisabled, start } as const;
|
||||
};
|
||||
@@ -1,72 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { SyncableMap } from 'common/util/SyncableMap/SyncableMap';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useEntityAdapterSafe } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import type { AnyObjectRenderer } from 'features/controlLayers/konva/CanvasObject/types';
|
||||
import { getEmptyRect } from 'features/controlLayers/konva/util';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { CanvasEntityIdentifier, Rect } from 'features/controlLayers/store/types';
|
||||
import { isTransformableEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useMemo, useSyncExternalStore } from 'react';
|
||||
|
||||
// When the entity is empty (the rect has no size) or there are no renderers, we have nothing to transform. Because the
|
||||
// entity is dynamic, and we need reactivity on these values, we need to do a little hack. These fallback objects
|
||||
// can be used to provide a default value for the useStore and useSyncExternalStore hooks, which require _some_ value
|
||||
// to be used.
|
||||
const $fallbackPixelRect = atom<Rect>(getEmptyRect());
|
||||
const fallbackRenderersMap = new SyncableMap<string, AnyObjectRenderer>();
|
||||
|
||||
export const useEntityTransform = (entityIdentifier: CanvasEntityIdentifier | null) => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const adapter = useEntityAdapterSafe(entityIdentifier);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isBusy = useCanvasIsBusy();
|
||||
// Use the fallback pixel rect if the adapter is not available
|
||||
const pixelRect = useStore(adapter?.transformer.$pixelRect ?? $fallbackPixelRect);
|
||||
// Use the fallback renderers map if the adapter is not available
|
||||
const renderers = useSyncExternalStore(
|
||||
adapter?.renderer.renderers.subscribe ?? fallbackRenderersMap.subscribe,
|
||||
adapter?.renderer.renderers.getSnapshot ?? fallbackRenderersMap.getSnapshot
|
||||
);
|
||||
|
||||
const start = useCallback(() => {
|
||||
if (!entityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isTransformableEntityIdentifier(entityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
const adapter = canvasManager.getAdapter(entityIdentifier);
|
||||
if (!adapter) {
|
||||
return;
|
||||
}
|
||||
adapter.transformer.startTransform();
|
||||
}, [entityIdentifier, canvasManager]);
|
||||
|
||||
const isDisabled = useMemo(() => {
|
||||
if (!entityIdentifier) {
|
||||
return true;
|
||||
}
|
||||
if (!isTransformableEntityIdentifier(entityIdentifier)) {
|
||||
return true;
|
||||
}
|
||||
if (!adapter) {
|
||||
return true;
|
||||
}
|
||||
if (isBusy || isStaging) {
|
||||
return true;
|
||||
}
|
||||
if (pixelRect.width === 0 || pixelRect.height === 0) {
|
||||
return true;
|
||||
}
|
||||
if (renderers.size === 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}, [entityIdentifier, adapter, isBusy, isStaging, pixelRect.width, pixelRect.height, renderers.size]);
|
||||
|
||||
return { isDisabled, start } as const;
|
||||
};
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user