Compare commits
994 Commits
ryan/cloth
...
v4.2.9.dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c5abd44a7 | ||
|
|
765d99ac2f | ||
|
|
ac9a66a628 | ||
|
|
0ea88dc170 | ||
|
|
8369826d22 | ||
|
|
0e354f5164 | ||
|
|
41f2ee2633 | ||
|
|
4e74006c5f | ||
|
|
48edb6e023 | ||
|
|
aeae6af0a1 | ||
|
|
ab11d9af8e | ||
|
|
2e84327ca4 | ||
|
|
fa6842121c | ||
|
|
c402aa397d | ||
|
|
a58c8adc38 | ||
|
|
d43e2d690e | ||
|
|
284f768810 | ||
|
|
e933d1ae2b | ||
|
|
1e134de771 | ||
|
|
29c47c8be5 | ||
|
|
e1122c541d | ||
|
|
2f81d1ac83 | ||
|
|
56fbe751db | ||
|
|
93f1d67fbf | ||
|
|
9467b937ff | ||
|
|
4242e6e6c2 | ||
|
|
9b39452b3e | ||
|
|
85b23784cf | ||
|
|
085cc82926 | ||
|
|
0098c33f81 | ||
|
|
292e00ab68 | ||
|
|
6c1fb2d06e | ||
|
|
d60605fcd8 | ||
|
|
38ed720ff2 | ||
|
|
22203b8eb0 | ||
|
|
cf5fa792a1 | ||
|
|
c636633a8e | ||
|
|
55fe1ebc53 | ||
|
|
3c2fa6b475 | ||
|
|
9b927de2e0 | ||
|
|
6a62854e7d | ||
|
|
312093cbb0 | ||
|
|
06fe14e1fc | ||
|
|
1b54e58726 | ||
|
|
219d7c9611 | ||
|
|
9f742a669e | ||
|
|
41e324fd51 | ||
|
|
ce55a96125 | ||
|
|
64e60a7fde | ||
|
|
972f03960a | ||
|
|
5a403f087d | ||
|
|
fe59d7f3b0 | ||
|
|
b2b2b73aed | ||
|
|
20b563c4cb | ||
|
|
263a0ef5b4 | ||
|
|
e8723b7cd3 | ||
|
|
03e05b2068 | ||
|
|
6c0482a71d | ||
|
|
e6153e6fa4 | ||
|
|
6d209c6cc3 | ||
|
|
beb4e823dc | ||
|
|
61ba4c606b | ||
|
|
af840cedf3 | ||
|
|
0bf0bca03f | ||
|
|
e470eaf8f3 | ||
|
|
377db3f726 | ||
|
|
77f020a997 | ||
|
|
34e2eda625 | ||
|
|
e1d559db69 | ||
|
|
23a98e2ed6 | ||
|
|
fe3b2ed357 | ||
|
|
eedf81dcc5 | ||
|
|
dbef1a9e06 | ||
|
|
a41406ca9a | ||
|
|
f126a61f66 | ||
|
|
89c79276f3 | ||
|
|
423e463b95 | ||
|
|
52202e45de | ||
|
|
100832c66d | ||
|
|
a58b91b221 | ||
|
|
3af6d79852 | ||
|
|
1303e18e93 | ||
|
|
301da97670 | ||
|
|
17e76981bb | ||
|
|
9c1732e2bb | ||
|
|
a3179e7a3f | ||
|
|
f86b50d18a | ||
|
|
307885f505 | ||
|
|
4b49c1dd6b | ||
|
|
f917cefa84 | ||
|
|
bea98438fc | ||
|
|
17d3275086 | ||
|
|
059b7a0fcf | ||
|
|
05d3a989f6 | ||
|
|
590ae70c12 | ||
|
|
5240ec6e6f | ||
|
|
04772b642c | ||
|
|
65f6cb416f | ||
|
|
24c2028739 | ||
|
|
b0db9a3f56 | ||
|
|
3ea83574c0 | ||
|
|
05252a9bfc | ||
|
|
ce854f086e | ||
|
|
ff0c16978c | ||
|
|
41cc650031 | ||
|
|
c3f7554053 | ||
|
|
3f597a1c60 | ||
|
|
ccffdf1878 | ||
|
|
474089e892 | ||
|
|
778e8ad161 | ||
|
|
9f29892c24 | ||
|
|
56fd46a069 | ||
|
|
579e594861 | ||
|
|
af3440fbe3 | ||
|
|
cc101f55c4 | ||
|
|
ef1adf07f5 | ||
|
|
625c05d9be | ||
|
|
8ad3d8f738 | ||
|
|
4759875733 | ||
|
|
768e6a3c55 | ||
|
|
45bd85c039 | ||
|
|
9f94c5a8bd | ||
|
|
23fdd65961 | ||
|
|
8034195c30 | ||
|
|
08761127c9 | ||
|
|
4a10010b6c | ||
|
|
14cc5e2453 | ||
|
|
3d87adea60 | ||
|
|
36e8232ab6 | ||
|
|
72722a73be | ||
|
|
a09aa232a9 | ||
|
|
7ae8b64699 | ||
|
|
60e0d17f34 | ||
|
|
bf8bef2f00 | ||
|
|
b586d67bac | ||
|
|
31e5e5af13 | ||
|
|
94871e88cd | ||
|
|
00e56d1968 | ||
|
|
43672a53ab | ||
|
|
45097ed2a6 | ||
|
|
871f6b9f95 | ||
|
|
e6476e3c75 | ||
|
|
ac9b5f246d | ||
|
|
8bc72a2744 | ||
|
|
f76f1d89d7 | ||
|
|
7b54762b5e | ||
|
|
bc6faf6a6d | ||
|
|
e7ae1ac9b2 | ||
|
|
dcb436adb1 | ||
|
|
80f0441905 | ||
|
|
8cde803654 | ||
|
|
62445680ad | ||
|
|
7685e36886 | ||
|
|
4c196844bd | ||
|
|
b36159bda4 | ||
|
|
b02948d49a | ||
|
|
f442d206be | ||
|
|
21ed6bccd8 | ||
|
|
143ce7f00b | ||
|
|
28e716139b | ||
|
|
80a7c0c521 | ||
|
|
255ad3d2ad | ||
|
|
089bc9c7d8 | ||
|
|
ee7dafaf57 | ||
|
|
516ecdb0ee | ||
|
|
b77675f74d | ||
|
|
eea5c8efad | ||
|
|
09f1aac3a3 | ||
|
|
dd1dcb5eba | ||
|
|
757bd62ebe | ||
|
|
5a3127949b | ||
|
|
ced934c0a3 | ||
|
|
c32445084f | ||
|
|
9f1af0cdaa | ||
|
|
0d26cab400 | ||
|
|
c8de2da3fc | ||
|
|
ca089a105e | ||
|
|
22000918d6 | ||
|
|
6affc28da4 | ||
|
|
f659995e1c | ||
|
|
56fb3e738f | ||
|
|
56d450a907 | ||
|
|
d3cdcef36b | ||
|
|
19434e73b4 | ||
|
|
f7b3df9583 | ||
|
|
4da4b3bd50 | ||
|
|
e83513882a | ||
|
|
5adc784b6b | ||
|
|
f177513523 | ||
|
|
8ebcf79b1a | ||
|
|
c7e5f24704 | ||
|
|
ab3eb32ec8 | ||
|
|
d76509e5cb | ||
|
|
04f56aab82 | ||
|
|
c7913cbbbb | ||
|
|
0556468518 | ||
|
|
1c7ef827b6 | ||
|
|
5720ed4d64 | ||
|
|
7f05af4a68 | ||
|
|
6db615ed5a | ||
|
|
465f020c86 | ||
|
|
f05b77088f | ||
|
|
80a5abf1ad | ||
|
|
7a6e8de60f | ||
|
|
8364fa74cf | ||
|
|
14f4566dd0 | ||
|
|
6145378923 | ||
|
|
68e2606427 | ||
|
|
0f3eb04d1a | ||
|
|
4a355323b2 | ||
|
|
8601fbb4ea | ||
|
|
db885aa180 | ||
|
|
c18fb980a2 | ||
|
|
b630dbdf20 | ||
|
|
29ac1b5e01 | ||
|
|
506d3b079e | ||
|
|
0670e6b53a | ||
|
|
76124ea35b | ||
|
|
6eae3470cd | ||
|
|
c7ba7ac876 | ||
|
|
edc733abd9 | ||
|
|
a56ded664e | ||
|
|
31ace5fb0c | ||
|
|
11010236b3 | ||
|
|
5f061ac1e2 | ||
|
|
72919fa34e | ||
|
|
d5ca99fc3c | ||
|
|
e49b72ee4e | ||
|
|
abe8db8154 | ||
|
|
e0e5941384 | ||
|
|
86e1f4e8b0 | ||
|
|
447d873ef0 | ||
|
|
b21d613ce4 | ||
|
|
fc91adb32f | ||
|
|
71885db5fd | ||
|
|
b88d14b3df | ||
|
|
d98d35a8a8 | ||
|
|
87bc0ebd73 | ||
|
|
7b6ba3f690 | ||
|
|
b0d8948428 | ||
|
|
b32d681cee | ||
|
|
11a66d1d09 | ||
|
|
e41987f08c | ||
|
|
34b57ec188 | ||
|
|
d74843be31 | ||
|
|
1216c6f9c9 | ||
|
|
865b6017d3 | ||
|
|
922a021821 | ||
|
|
0b5f4cac57 | ||
|
|
c988c58c63 | ||
|
|
ceb8cbf59e | ||
|
|
52e9f43c46 | ||
|
|
4e5e7761fc | ||
|
|
9879999a65 | ||
|
|
bedaca70a3 | ||
|
|
2dd2225d2e | ||
|
|
d82031eec1 | ||
|
|
e5f2860b74 | ||
|
|
fa3560bb61 | ||
|
|
9b23f6ce30 | ||
|
|
5d6aa6cfd5 | ||
|
|
7d1819335f | ||
|
|
539e7a3f2d | ||
|
|
1686924ac8 | ||
|
|
556c1dc67b | ||
|
|
00f7093e65 | ||
|
|
79eb11dce9 | ||
|
|
0bf48c0d41 | ||
|
|
3f33e5f770 | ||
|
|
da3888ba9e | ||
|
|
a2f91b1055 | ||
|
|
d26095dfa1 | ||
|
|
83e786bd1e | ||
|
|
4cae12a507 | ||
|
|
d8e3708e0f | ||
|
|
f4de2fd3b1 | ||
|
|
e1cb30bbb4 | ||
|
|
97e0edc549 | ||
|
|
f4e66bf14f | ||
|
|
a6a7fe8aba | ||
|
|
a273f72560 | ||
|
|
b5126f45d6 | ||
|
|
ba3bb7cbf3 | ||
|
|
608279487b | ||
|
|
72b5374916 | ||
|
|
08b03212ca | ||
|
|
7e341a05a1 | ||
|
|
e665d08ee1 | ||
|
|
ba6362dc9d | ||
|
|
48f0797c43 | ||
|
|
640b0c4939 | ||
|
|
287c61e277 | ||
|
|
f7b2516109 | ||
|
|
b530eb49d4 | ||
|
|
fa94979ab6 | ||
|
|
54f2acf5b9 | ||
|
|
b6d845a4d0 | ||
|
|
1095b7c37f | ||
|
|
136ffd97ca | ||
|
|
80163d0af2 | ||
|
|
e1c6e926e7 | ||
|
|
2bb74abf31 | ||
|
|
0d4b91afe0 | ||
|
|
6c688d6878 | ||
|
|
243feecef9 | ||
|
|
abd22ba087 | ||
|
|
ab25546e97 | ||
|
|
925f0fca2a | ||
|
|
066366d885 | ||
|
|
61d52e96b7 | ||
|
|
051e88ca90 | ||
|
|
e873b69850 | ||
|
|
661fd55556 | ||
|
|
402f5a4717 | ||
|
|
81bf52ef37 | ||
|
|
8ff92796df | ||
|
|
68af60e12e | ||
|
|
cce6bf9428 | ||
|
|
078908fbea | ||
|
|
7275caaf5b | ||
|
|
d9487c1df4 | ||
|
|
3a9f955388 | ||
|
|
e46c7acd2e | ||
|
|
b771664851 | ||
|
|
7c21819d20 | ||
|
|
a57e618d47 | ||
|
|
c9849a79ea | ||
|
|
f1643fec08 | ||
|
|
951e63ca87 | ||
|
|
8e539c8a8c | ||
|
|
1e689a4902 | ||
|
|
7bbd25b5ec | ||
|
|
b1c7236117 | ||
|
|
ae3e473024 | ||
|
|
fd616f247c | ||
|
|
45dca2c821 | ||
|
|
40dc108c84 | ||
|
|
a421c25952 | ||
|
|
562d0afdbb | ||
|
|
2ce4698eef | ||
|
|
cb53108041 | ||
|
|
5fa65e5cc6 | ||
|
|
e8b0b6cef5 | ||
|
|
eca2712828 | ||
|
|
2804c0aede | ||
|
|
0429f0480d | ||
|
|
024759a0fc | ||
|
|
9a94aef2b0 | ||
|
|
e329cb45cd | ||
|
|
0dc38bd684 | ||
|
|
98ebca5f8c | ||
|
|
05cb3e03cf | ||
|
|
181132c149 | ||
|
|
a69aa00155 | ||
|
|
47d415e31c | ||
|
|
667a156817 | ||
|
|
00f39b977e | ||
|
|
e5776e2bd6 | ||
|
|
2b21f54897 | ||
|
|
678d12fcd5 | ||
|
|
03f06f611e | ||
|
|
6571e0f814 | ||
|
|
44f91026e1 | ||
|
|
56237328f1 | ||
|
|
ff68901e89 | ||
|
|
e0e7adb2b2 | ||
|
|
0923a5b128 | ||
|
|
75f8a84c79 | ||
|
|
af815cf7eb | ||
|
|
ef4d6c26f6 | ||
|
|
5087b306c0 | ||
|
|
a5708eaefe | ||
|
|
389bfc9e31 | ||
|
|
fd269e91e0 | ||
|
|
80136b0dfc | ||
|
|
9595eff1f9 | ||
|
|
c3c95754f7 | ||
|
|
22ab63fe8d | ||
|
|
5fefcab475 | ||
|
|
771a05b894 | ||
|
|
e2d8aaa923 | ||
|
|
0951aecb13 | ||
|
|
b1fe6f9853 | ||
|
|
551dd393aa | ||
|
|
78b4562184 | ||
|
|
c49b90e621 | ||
|
|
89e6233fbf | ||
|
|
3f9496c237 | ||
|
|
36e94af598 | ||
|
|
a181a684f5 | ||
|
|
bb712b3b3f | ||
|
|
e795de5647 | ||
|
|
bdc428cdd8 | ||
|
|
e4376e21dd | ||
|
|
77acc7baed | ||
|
|
9db1556c4d | ||
|
|
65de8b329b | ||
|
|
08dae5b047 | ||
|
|
8d2f056407 | ||
|
|
e66ef2e25e | ||
|
|
4d3ee7e082 | ||
|
|
fe48fda2f3 | ||
|
|
0f66753aa1 | ||
|
|
a18878474b | ||
|
|
0aa4568fd4 | ||
|
|
1de7e5760a | ||
|
|
135d6f2763 | ||
|
|
061767ede3 | ||
|
|
7204844bcb | ||
|
|
f2279ecadd | ||
|
|
75694869d2 | ||
|
|
d029680ac1 | ||
|
|
41c195d936 | ||
|
|
03ea005e9c | ||
|
|
6d936a7c44 | ||
|
|
fba17b93a6 | ||
|
|
73a7a27ea1 | ||
|
|
79287c2d16 | ||
|
|
662c5f4b77 | ||
|
|
7728ca6843 | ||
|
|
9607372f89 | ||
|
|
d27f948b78 | ||
|
|
b7aab81717 | ||
|
|
2998287f61 | ||
|
|
55d7f0ff5b | ||
|
|
4564f36d4a | ||
|
|
319de5c4e9 | ||
|
|
eee499faa3 | ||
|
|
63c5e42f2a | ||
|
|
bd16dc4479 | ||
|
|
49371ddec9 | ||
|
|
6a10d31b19 | ||
|
|
c951e733d3 | ||
|
|
7ed24cf847 | ||
|
|
821b7a0435 | ||
|
|
1b0344c412 | ||
|
|
03ca3c4b3d | ||
|
|
b939192b16 | ||
|
|
7ccf559a06 | ||
|
|
9eb091f873 | ||
|
|
3bd5521641 | ||
|
|
ced748e419 | ||
|
|
fbd137da9f | ||
|
|
03baebced6 | ||
|
|
cb19c1c370 | ||
|
|
788bad61d0 | ||
|
|
8f5f9bd44e | ||
|
|
2873e3e084 | ||
|
|
b004f17ae3 | ||
|
|
bea1e8c99b | ||
|
|
111493223f | ||
|
|
0a5ac2baec | ||
|
|
eec3c3b884 | ||
|
|
07b72c3d70 | ||
|
|
766e8c4eb0 | ||
|
|
57c257d10d | ||
|
|
d497da0e61 | ||
|
|
62310e7929 | ||
|
|
d79aa173a6 | ||
|
|
fbfdd3e003 | ||
|
|
a62b4a26ef | ||
|
|
817d4168c6 | ||
|
|
7e0a6d1538 | ||
|
|
ebc498ad19 | ||
|
|
b97b8c6ce6 | ||
|
|
b8abff65a1 | ||
|
|
a953dc1dbd | ||
|
|
a7c9848e99 | ||
|
|
73a1449eaf | ||
|
|
59f57ff542 | ||
|
|
e9204b87e3 | ||
|
|
7dd11bd60a | ||
|
|
275fc2ccf9 | ||
|
|
a2ef8d9d47 | ||
|
|
196779ff19 | ||
|
|
aee3147365 | ||
|
|
eaca940956 | ||
|
|
06006733e2 | ||
|
|
14d0bfbef6 | ||
|
|
0c9cf73702 | ||
|
|
3b864921ac | ||
|
|
f41539532f | ||
|
|
657009c254 | ||
|
|
c47e02c309 | ||
|
|
ce8a7bc178 | ||
|
|
488ca87787 | ||
|
|
d965df8ca9 | ||
|
|
995c26751e | ||
|
|
dd09723a2a | ||
|
|
5ff5af3ba2 | ||
|
|
4cb85404c0 | ||
|
|
50bc2f100d | ||
|
|
f65ce6a019 | ||
|
|
c28b635f2d | ||
|
|
e55896240d | ||
|
|
2b478ee7e1 | ||
|
|
69912a35ea | ||
|
|
9f1bd98c7e | ||
|
|
b531d6b7f0 | ||
|
|
8aa963fb81 | ||
|
|
b76e0ab4e4 | ||
|
|
aea03b4e92 | ||
|
|
b39e95966c | ||
|
|
d53e5e0158 | ||
|
|
0368dd651b | ||
|
|
84a4a1024e | ||
|
|
af4f258489 | ||
|
|
ddfc8785b4 | ||
|
|
d8515b6efc | ||
|
|
6a07f007a4 | ||
|
|
7a5a0c8075 | ||
|
|
5ed2e9b0fc | ||
|
|
aeb0a45eb6 | ||
|
|
21e814d766 | ||
|
|
cafc1839e2 | ||
|
|
e937aa831f | ||
|
|
890e6a95ed | ||
|
|
a5b7274359 | ||
|
|
172acf2cf5 | ||
|
|
b49fdf6407 | ||
|
|
5184d05bc2 | ||
|
|
7ef4553fc9 | ||
|
|
d6bd1e4a49 | ||
|
|
29413f20a7 | ||
|
|
04a44c8ea7 | ||
|
|
426f1b6f9a | ||
|
|
9c7f5ed321 | ||
|
|
4c37c7f280 | ||
|
|
a2d13cacbf | ||
|
|
aa127b83a3 | ||
|
|
e55192ae2a | ||
|
|
5159fcbc33 | ||
|
|
02ad7a0f93 | ||
|
|
bfa496e37f | ||
|
|
fdf347af26 | ||
|
|
0833dbb19d | ||
|
|
1b6bf58e58 | ||
|
|
5ead7bc7b4 | ||
|
|
f326d17856 | ||
|
|
908aa9beea | ||
|
|
4071e96245 | ||
|
|
b4daf29bd8 | ||
|
|
bf185339c2 | ||
|
|
df3abc75c2 | ||
|
|
28fc9a387c | ||
|
|
8533f207dc | ||
|
|
d135c48319 | ||
|
|
ca9090d070 | ||
|
|
93b185dc3b | ||
|
|
98e5efa895 | ||
|
|
c6774b829d | ||
|
|
22925f92bd | ||
|
|
302efcf6e8 | ||
|
|
76f9f90f0a | ||
|
|
5ba338e471 | ||
|
|
01f101c6f2 | ||
|
|
5606aec78d | ||
|
|
db90e1fe8b | ||
|
|
ae96c479f2 | ||
|
|
344ed2c83e | ||
|
|
1985944659 | ||
|
|
915357a6c1 | ||
|
|
63c34e78d7 | ||
|
|
366c460c1f | ||
|
|
40cab08133 | ||
|
|
51de25122a | ||
|
|
90313091db | ||
|
|
9982219d18 | ||
|
|
b3fe03b8f9 | ||
|
|
6edd15d68a | ||
|
|
0e2b328c88 | ||
|
|
25d7f9c316 | ||
|
|
3870ebdf29 | ||
|
|
7595d05191 | ||
|
|
21af727d79 | ||
|
|
5691829de6 | ||
|
|
20e6a57cf1 | ||
|
|
d0c40a8b5b | ||
|
|
f663215f25 | ||
|
|
7c5dea6d12 | ||
|
|
87261bdbc9 | ||
|
|
4e4b6c6dbc | ||
|
|
5e8cf9fb6a | ||
|
|
c738fe051f | ||
|
|
29fe1533f2 | ||
|
|
77090070bd | ||
|
|
6ba9b1b6b0 | ||
|
|
c578b8df1e | ||
|
|
cad9a41433 | ||
|
|
5fefb3b0f4 | ||
|
|
5284a870b0 | ||
|
|
e064377c05 | ||
|
|
3e569c8312 | ||
|
|
16825ee6e9 | ||
|
|
3f5340fa53 | ||
|
|
f2a1a39b33 | ||
|
|
326de55d3e | ||
|
|
b2df909570 | ||
|
|
026ac36b06 | ||
|
|
92125e5fd2 | ||
|
|
c0c139da88 | ||
|
|
404ad6a7fd | ||
|
|
fc39086fb4 | ||
|
|
cd215700fe | ||
|
|
e97fd85904 | ||
|
|
0a263fa5b1 | ||
|
|
fae3836a8d | ||
|
|
b3d2eb4178 | ||
|
|
576f1cbb75 | ||
|
|
50085b40bb | ||
|
|
cff382715a | ||
|
|
54d54d1bf2 | ||
|
|
e84ea68282 | ||
|
|
160dd36782 | ||
|
|
65bb46bcca | ||
|
|
2d185fb766 | ||
|
|
2ba9b02932 | ||
|
|
849da67cc7 | ||
|
|
3ea6c9666e | ||
|
|
cf633e4ef2 | ||
|
|
bbf934d980 | ||
|
|
620f733110 | ||
|
|
67928609a3 | ||
|
|
5f15afb7db | ||
|
|
635d2f480d | ||
|
|
70c278c810 | ||
|
|
56b9906e2e | ||
|
|
a808ce81fd | ||
|
|
83f82c5ddf | ||
|
|
101de8c25d | ||
|
|
3339a4baf0 | ||
|
|
dff4a88baa | ||
|
|
a21f6c4964 | ||
|
|
97562504b7 | ||
|
|
75d8ac378c | ||
|
|
b9dd354e2b | ||
|
|
33c2fbd201 | ||
|
|
5063be92bf | ||
|
|
1047584b3e | ||
|
|
6764dcfdaa | ||
|
|
012864ceb1 | ||
|
|
a0bf20bcee | ||
|
|
14ab339b33 | ||
|
|
25c91efbb6 | ||
|
|
1c1f2c6664 | ||
|
|
d7c22b3bf7 | ||
|
|
185f2a395f | ||
|
|
0c5649491e | ||
|
|
94aba5892a | ||
|
|
ef093dde29 | ||
|
|
34451e5f27 | ||
|
|
1f9bdd1a9a | ||
|
|
c27d59baf7 | ||
|
|
f130ddec7c | ||
|
|
a0a259eef1 | ||
|
|
b66f19d4d1 | ||
|
|
4105a78b83 | ||
|
|
19a68afb3a | ||
|
|
fd68a2475b | ||
|
|
28ff7ba830 | ||
|
|
5d0b248fdb | ||
|
|
01a4e0f6ef | ||
|
|
91e0731506 | ||
|
|
d1f904d41f | ||
|
|
269388c9f4 | ||
|
|
b8486379ce | ||
|
|
400eb94d3b | ||
|
|
e210c96485 | ||
|
|
5f567f41f4 | ||
|
|
5fed573a29 | ||
|
|
cfac7c8189 | ||
|
|
1787de6836 | ||
|
|
ac96f187bd | ||
|
|
72398350b4 | ||
|
|
df9445c351 | ||
|
|
87b7a2e39b | ||
|
|
f7e46622a1 | ||
|
|
71f18353a9 | ||
|
|
4228de707b | ||
|
|
b6a05629ef | ||
|
|
fbaa820643 | ||
|
|
db2a2d5e38 | ||
|
|
8ba6e6b1f8 | ||
|
|
57168d719b | ||
|
|
dee6d2c98e | ||
|
|
e49105ece5 | ||
|
|
0c5e11f521 | ||
|
|
a63f842a13 | ||
|
|
4bd7fda694 | ||
|
|
81f0886d6f | ||
|
|
2eb87f3306 | ||
|
|
723f3ab0a9 | ||
|
|
1bd90e0fd4 | ||
|
|
436f18ff55 | ||
|
|
cde9696214 | ||
|
|
2d9042fb93 | ||
|
|
9ed53af520 | ||
|
|
56fda669fd | ||
|
|
1d8545a76c | ||
|
|
5f59a828f9 | ||
|
|
1fa6bddc89 | ||
|
|
d3a5ca5247 | ||
|
|
f01f56a98e | ||
|
|
99b0f79784 | ||
|
|
e1eb104345 | ||
|
|
5c2f95ef50 | ||
|
|
b63df9bab9 | ||
|
|
a52c899c6d | ||
|
|
eeabb7ebe5 | ||
|
|
8b1cef978c | ||
|
|
152da482cd | ||
|
|
3cf0365a35 | ||
|
|
5870742bb9 | ||
|
|
01d8c62c57 | ||
|
|
55a242b2d6 | ||
|
|
45263b339f | ||
|
|
3319491861 | ||
|
|
e687afac90 | ||
|
|
b39031ea53 | ||
|
|
0b77511271 | ||
|
|
c99cd989c1 | ||
|
|
317fdadb21 | ||
|
|
4e294f9e3e | ||
|
|
526e0f30a0 | ||
|
|
231e5ec94a | ||
|
|
e5bb6f9693 | ||
|
|
da7dee44c6 | ||
|
|
83144f4fe3 | ||
|
|
c451f52ea3 | ||
|
|
8a2c78f2e1 | ||
|
|
bcc78bde9b | ||
|
|
054bb6fe0a | ||
|
|
4f4aa6d92e | ||
|
|
eac51ac6f5 | ||
|
|
9f349a7c0a | ||
|
|
918afa5b15 | ||
|
|
eb1113f95c | ||
|
|
4f4ba7b462 | ||
|
|
2298be0e6b | ||
|
|
63494dfca7 | ||
|
|
36a1d39454 | ||
|
|
a6f6d5c400 | ||
|
|
e85f221aca | ||
|
|
d4797e37dc | ||
|
|
3e7923d072 | ||
|
|
a85d69ce3d | ||
|
|
96db006c99 | ||
|
|
8ca57d03d8 | ||
|
|
6c404ce5f8 | ||
|
|
584e07182b | ||
|
|
f787e9acf6 | ||
|
|
5a24b89e54 | ||
|
|
9b482e2a4f | ||
|
|
df4dbe2d57 | ||
|
|
713bd11177 | ||
|
|
182571df4b | ||
|
|
29bfe492b6 | ||
|
|
3fb4e3050c | ||
|
|
39c7ec3cd9 | ||
|
|
26bfbdec7f | ||
|
|
7a3eaa8da9 | ||
|
|
599db7296f | ||
|
|
042aab4295 | ||
|
|
24f298283f | ||
|
|
68dac6349d | ||
|
|
b675fc19e8 | ||
|
|
659019cfd6 | ||
|
|
dcd61e1f82 | ||
|
|
f5c99b1488 | ||
|
|
810be3e1d4 | ||
|
|
60d754d1df | ||
|
|
bd07c86db9 | ||
|
|
bcbf8b6bd8 | ||
|
|
356661459b | ||
|
|
deb917825e | ||
|
|
15415c6d85 | ||
|
|
76b0380b5f | ||
|
|
2d58754789 | ||
|
|
9cdf1f599c | ||
|
|
268be97ba0 | ||
|
|
a9014673a0 | ||
|
|
d36c43a10f | ||
|
|
54a5c4e482 | ||
|
|
5e09a244e3 | ||
|
|
88648dca1a | ||
|
|
8840df2b00 | ||
|
|
af159acbdf | ||
|
|
471719bbbe | ||
|
|
b126f2ffd5 | ||
|
|
9938f12ef0 | ||
|
|
982c266073 | ||
|
|
5c37391883 | ||
|
|
ddeafc6833 | ||
|
|
41b2d5d013 | ||
|
|
29d6f48901 | ||
|
|
d5c9f4e47f | ||
|
|
24d73387d8 | ||
|
|
e0d3927265 | ||
|
|
e5f7c2a9b7 | ||
|
|
b0760710d5 | ||
|
|
764accc921 | ||
|
|
6a01fce9c1 | ||
|
|
9c732ac3b1 | ||
|
|
b70891c661 | ||
|
|
4dbf851741 | ||
|
|
6c927a9fd4 | ||
|
|
096f001634 | ||
|
|
4837e578b2 | ||
|
|
1e547ef912 | ||
|
|
f6b8970bd1 | ||
|
|
29325a7214 | ||
|
|
8ecf72838d | ||
|
|
c3ab8a6aa8 | ||
|
|
1931aa3e70 | ||
|
|
d3d8055055 | ||
|
|
476b0a0403 | ||
|
|
f66584713c | ||
|
|
33624fc2fa | ||
|
|
41c3e73a3c | ||
|
|
97553a7de2 | ||
|
|
12ba15bfa9 | ||
|
|
09d1e190e7 | ||
|
|
8eb5d08499 | ||
|
|
9be6acde7d | ||
|
|
5f83bb0069 | ||
|
|
b138882abc | ||
|
|
0cd7cdb52e | ||
|
|
1d8b7e2bcf | ||
|
|
6461f4758d | ||
|
|
3189ab6863 | ||
|
|
3f9a674d4b | ||
|
|
587f59b25b | ||
|
|
4952eada87 | ||
|
|
581029ebaa | ||
|
|
42d68780de | ||
|
|
28032a2f80 | ||
|
|
e381e021e9 | ||
|
|
641af64f93 | ||
|
|
a7b83c8b5b | ||
|
|
4cc41e0188 | ||
|
|
442fc02429 | ||
|
|
9a4d075074 | ||
|
|
17ff8196cb | ||
|
|
68f993998a | ||
|
|
7da6120b39 | ||
|
|
6cd40965c4 | ||
|
|
408a1d6dbb | ||
|
|
0b0abfbe8f | ||
|
|
cc96dcf0ed | ||
|
|
2604fd9fde | ||
|
|
140670d00e | ||
|
|
70233fae5d | ||
|
|
6f457a6c4c | ||
|
|
5c319f5356 | ||
|
|
991a04f090 | ||
|
|
c39fa75113 | ||
|
|
f7863e17ce | ||
|
|
7c526390ed | ||
|
|
2cff20f87a | ||
|
|
90ec757802 | ||
|
|
4b85dfcefe | ||
|
|
21deefdc41 | ||
|
|
857d74bbfe | ||
|
|
fd7a635777 | ||
|
|
af9110e964 | ||
|
|
a61209206b | ||
|
|
e05cc62e5f | ||
|
|
4d4f921a4e | ||
|
|
98db8f395b | ||
|
|
f465a956a3 | ||
|
|
9edb02d7ef | ||
|
|
6c4cf58a31 | ||
|
|
08993c0d29 | ||
|
|
4f8a4b0f22 | ||
|
|
a743f3c9b5 | ||
|
|
217fe40d99 | ||
|
|
b76bf50b93 | ||
|
|
571ba87e13 | ||
|
|
f27b6e2b44 | ||
|
|
981475a624 | ||
|
|
27ac61a4fb | ||
|
|
675ffc2757 | ||
|
|
44b21f10f1 | ||
|
|
c6d49e8b1f | ||
|
|
e6a512aa86 | ||
|
|
c3a6a6fb22 | ||
|
|
b9dc3460ba | ||
|
|
63581ec980 | ||
|
|
08b1feeed7 | ||
|
|
f5cfdcf32d | ||
|
|
e78fb428f0 | ||
|
|
31e270e32c | ||
|
|
b5832768dc | ||
|
|
4ce64b69cb | ||
|
|
5a9173f766 | ||
|
|
0bb7ed44f6 | ||
|
|
332bc9da5b | ||
|
|
08def3da95 | ||
|
|
daf899f9c4 | ||
|
|
13fb2d1f49 | ||
|
|
95dde802ea | ||
|
|
fca119773b | ||
|
|
0193267a53 | ||
|
|
b4cf78a95d | ||
|
|
73386826d6 | ||
|
|
9f448fecb7 | ||
|
|
bcd1483a14 | ||
|
|
e206890e25 | ||
|
|
0a7048f650 | ||
|
|
e8ecf5e155 | ||
|
|
33e8604b57 | ||
|
|
cec7399366 | ||
|
|
bdae81e429 | ||
|
|
67c32f3d6c | ||
|
|
94d64b8a78 | ||
|
|
fa3c0c81b3 | ||
|
|
66547b99c1 | ||
|
|
328e58be4c | ||
|
|
18f89ed5ed | ||
|
|
5701c79fab | ||
|
|
2da9f913f3 | ||
|
|
6b10b59abe | ||
|
|
918f77bce0 | ||
|
|
f170697ebe | ||
|
|
556c6a1d84 | ||
|
|
aca2a2fa13 | ||
|
|
ff6398f7d8 | ||
|
|
cf996472b9 | ||
|
|
156d14c349 | ||
|
|
86f705bf48 | ||
|
|
1fd9631f2d | ||
|
|
2227a2357f | ||
|
|
58e7ab157d | ||
|
|
8d16fa6a49 | ||
|
|
55e810efa3 | ||
|
|
2755316021 | ||
|
|
6525f18610 | ||
|
|
2ad13ac7eb | ||
|
|
693a3eaff5 | ||
|
|
ffca792d5b | ||
|
|
86a92bb6b5 | ||
|
|
171a4e6d80 | ||
|
|
e3a75a8adf | ||
|
|
ee7503ce13 | ||
|
|
8500bac3ca | ||
|
|
310719eb4c | ||
|
|
e8e24822ec | ||
|
|
c57a7afb87 | ||
|
|
84d028898c | ||
|
|
ed0174fbc6 | ||
|
|
9e582563eb | ||
|
|
faa88f72bf | ||
|
|
0d69a31df0 | ||
|
|
daa5a88eb2 | ||
|
|
5b84e117b2 | ||
|
|
eb257d2d28 | ||
|
|
5810cee6c9 | ||
|
|
eef88d1f83 | ||
|
|
78f6850fc0 | ||
|
|
bd8890be11 | ||
|
|
adf1a977ea | ||
|
|
e1509bcb45 | ||
|
|
edcaf8287d | ||
|
|
39bd30f2a0 | ||
|
|
102b47190f | ||
|
|
269fe2e3bb | ||
|
|
b32aa1c77f | ||
|
|
6656544ed5 | ||
|
|
4c75b93410 | ||
|
|
5be0de967d | ||
|
|
f8e27b837b | ||
|
|
47414be1e6 | ||
|
|
74cef38bcf | ||
|
|
bb876b8d4e | ||
|
|
e5d9ca013e | ||
|
|
4166c756ce | ||
|
|
4f0dfbd34d | ||
|
|
46c632e7cc | ||
|
|
653f63ae71 | ||
|
|
8a9e2f57a4 | ||
|
|
31949ed2f2 | ||
|
|
0ccb304b8b | ||
|
|
ab0bfa709a | ||
|
|
6af659b1da | ||
|
|
416d29fb83 | ||
|
|
19c00241c6 | ||
|
|
c323a760a5 | ||
|
|
9d1fcba415 | ||
|
|
ca21996a97 | ||
|
|
62aa064e56 | ||
|
|
87eb018380 | ||
|
|
5003e5d763 | ||
|
|
58f3072b91 | ||
|
|
9e7b470189 |
2
.github/workflows/python-checks.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff
|
||||
run: pip install ruff==0.6.0
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
|
||||
2
.github/workflows/python-tests.yml
vendored
@@ -60,7 +60,7 @@ jobs:
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
os: macOS-12
|
||||
os: macOS-14
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: windows-cpu
|
||||
os: windows-2022
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aeb428d0-0817-462c-b5d8-455a0615d305",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from PIL import Image\n",
|
||||
"import numpy as np\n",
|
||||
"import cv2\n",
|
||||
"\n",
|
||||
"from invokeai.backend.vto_workflow.overlay_pattern import generate_dress_mask, multiply_images\n",
|
||||
"from invokeai.backend.vto_workflow.extract_channel import extract_channel, ImageChannel\n",
|
||||
"from invokeai.backend.vto_workflow.seamless_mapping import map_seamless_tiles\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6140d4b7-8238-431c-848e-6f6ae27652f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" # Load the model image.\n",
|
||||
"model_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg\")\n",
|
||||
"\n",
|
||||
"# Load the pattern image.\n",
|
||||
"pattern_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fb7186ba-dc0c-4520-ac30-49073a65601a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mask = generate_dress_mask(model_image)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b935de4-94c5-4be5-bf8e-a5a6e445c811",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize mask\n",
|
||||
"model_image_np = np.array(model_image)\n",
|
||||
"masked_model_image = (model_image_np * np.expand_dims(mask, -1).astype(np.float32)).astype(np.uint8)\n",
|
||||
"mask_image = Image.fromarray(masked_model_image)\n",
|
||||
"mask_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e51bb545",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"shadows = extract_channel(np.array(model_image), ImageChannel.LAB_L)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec43de4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize masked shadows\n",
|
||||
"masked_shadows = (shadows * mask).astype(np.uint8)\n",
|
||||
"masked_shadows_image = Image.fromarray(masked_shadows)\n",
|
||||
"masked_shadows_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dbb53794",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Tile the pattern.\n",
|
||||
"expanded_pattern = map_seamless_tiles(seamless_tile=pattern_image, target_hw=(model_image.height, model_image.width), num_repeats_h=10.0)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f4f22d02",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Multiply the pattern by the shadows.\n",
|
||||
"pattern_with_shadows = multiply_images(expanded_pattern, shadows)\n",
|
||||
"pattern_with_shadows"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97db42b0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de32f7e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Merge the pattern with the model image.\n",
|
||||
"pattern_with_shadows_np = np.array(pattern_with_shadows)\n",
|
||||
"merged_image = np.where(mask[:, :, None], pattern_with_shadows_np,model_image_np)\n",
|
||||
"merged_image = Image.fromarray(merged_image)\n",
|
||||
"merged_image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff1d4044",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
# Invoke in Docker
|
||||
|
||||
- Ensure that Docker can use the GPU on your system
|
||||
- This documentation assumes Linux, but should work similarly under Windows with WSL2
|
||||
First things first:
|
||||
|
||||
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
|
||||
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
|
||||
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
|
||||
|
||||
## Quickstart :lightning:
|
||||
## Quickstart
|
||||
|
||||
No `docker compose`, no persistence, just a simple one-liner using the official images:
|
||||
No `docker compose`, no persistence, single command, using the official images:
|
||||
|
||||
**CUDA:**
|
||||
**CUDA (NVIDIA GPU):**
|
||||
|
||||
```bash
|
||||
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
|
||||
```
|
||||
|
||||
**ROCm:**
|
||||
**ROCm (AMD GPU):**
|
||||
|
||||
```bash
|
||||
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
|
||||
@@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
|
||||
|
||||
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
|
||||
|
||||
> [!TIP]
|
||||
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
|
||||
### Data persistence
|
||||
|
||||
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
|
||||
|
||||
```bash
|
||||
docker run --volume /some/local/path:/invokeai {...etc...}
|
||||
```
|
||||
|
||||
`/some/local/path/invokeai` will contain all your data.
|
||||
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
|
||||
|
||||
## Customize the container
|
||||
|
||||
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
|
||||
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
@@ -38,11 +48,14 @@ cp .env.sample .env
|
||||
|
||||
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
|
||||
|
||||
>[!TIP]
|
||||
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
|
||||
|
||||
## Docker setup in detail
|
||||
|
||||
#### Linux
|
||||
|
||||
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
|
||||
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
|
||||
3. Ensure docker daemon is able to access the GPU.
|
||||
@@ -98,25 +111,7 @@ GPU_DRIVER=cuda
|
||||
|
||||
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
||||
|
||||
## Even More Customizing!
|
||||
---
|
||||
|
||||
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||
|
||||
### Reconfigure the runtime directory
|
||||
|
||||
Can be used to download additional models from the supported model list
|
||||
|
||||
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
|
||||
|
||||
```yaml
|
||||
command:
|
||||
- invokeai-configure
|
||||
- --yes
|
||||
```
|
||||
|
||||
Or install models:
|
||||
|
||||
```yaml
|
||||
command:
|
||||
- invokeai-model-install
|
||||
```
|
||||
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
||||
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
@@ -31,6 +32,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@@ -63,7 +66,12 @@ class ApiDependencies:
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||
def initialize(
|
||||
config: InvokeAIAppConfig,
|
||||
event_handler_id: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
logger: Logger = logger,
|
||||
) -> None:
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
|
||||
@@ -74,6 +82,7 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
style_presets_folder = config.style_presets_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
@@ -84,7 +93,7 @@ class ApiDependencies:
|
||||
board_images = BoardImagesService()
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
events = FastAPIEventService(event_handler_id, loop=loop)
|
||||
bulk_download = BulkDownloadService()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
@@ -109,6 +118,8 @@ class ApiDependencies:
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -134,6 +145,8 @@ class ApiDependencies:
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -218,9 +218,8 @@ async def get_image_workflow(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
@images_router.get(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -231,6 +230,18 @@ async def get_image_workflow(
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
@images_router.head(
|
||||
"/i/{image_name}/full",
|
||||
operation_id="get_image_full_head",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> Response:
|
||||
@@ -242,6 +253,7 @@ async def get_image_full(
|
||||
content = f.read()
|
||||
response = Response(content, media_type="image/png")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByOriginResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
PruneResult,
|
||||
@@ -105,6 +106,19 @@ async def cancel_by_batch_ids(
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_origin",
|
||||
operation_id="cancel_by_origin",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_origin(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
origin: str = Query(description="The origin to cancel all queue items for"),
|
||||
) -> CancelByOriginResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/clear",
|
||||
operation_id="clear",
|
||||
|
||||
274
invokeai/app/api/routers/style_presets.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
InvalidPresetImportDataError,
|
||||
PresetData,
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordWithImage,
|
||||
StylePresetWithoutId,
|
||||
UnsupportedFileTypeError,
|
||||
parse_presets_from_file,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetFormData(BaseModel):
|
||||
name: str = Field(description="Preset name")
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
type: PresetType = Field(description="Preset type")
|
||||
|
||||
|
||||
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="get_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def get_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to get"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Gets a style preset"""
|
||||
try:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
|
||||
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
|
||||
except StylePresetNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Style preset not found")
|
||||
|
||||
|
||||
@style_presets_router.patch(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="update_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def update_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
style_preset_id: str = Path(description="The id of the style preset to update"),
|
||||
data: str = Form(description="The data of the style preset to update"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Updates a style preset"""
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
else:
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
|
||||
|
||||
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
||||
style_preset_id=style_preset_id, changes=changes
|
||||
)
|
||||
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.delete(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="delete_style_preset",
|
||||
)
|
||||
async def delete_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to delete"),
|
||||
) -> None:
|
||||
"""Deletes a style preset"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/",
|
||||
operation_id="create_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def create_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
data: str = Form(description="The data of the style preset to create"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Creates a style preset"""
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
|
||||
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
|
||||
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
|
||||
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/",
|
||||
operation_id="list_style_presets",
|
||||
responses={
|
||||
200: {"model": list[StylePresetRecordWithImage]},
|
||||
},
|
||||
)
|
||||
async def list_style_presets() -> list[StylePresetRecordWithImage]:
|
||||
"""Gets a page of style presets"""
|
||||
style_presets_with_image: list[StylePresetRecordWithImage] = []
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
|
||||
for preset in style_presets:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
|
||||
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
|
||||
style_presets_with_image.append(style_preset_with_image)
|
||||
|
||||
return style_presets_with_image
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}/image",
|
||||
operation_id="get_style_preset_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The style preset image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The style preset image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_style_preset_image(
|
||||
style_preset_id: str = Path(description="The id of the style preset image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=style_preset_id + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/export",
|
||||
operation_id="export_style_presets",
|
||||
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
|
||||
status_code=200,
|
||||
)
|
||||
async def export_style_presets():
|
||||
# Create an in-memory stream to store the CSV data
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write the header
|
||||
writer.writerow(["name", "prompt", "negative_prompt"])
|
||||
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
|
||||
|
||||
for preset in style_presets:
|
||||
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
|
||||
|
||||
csv_data = output.getvalue()
|
||||
output.close()
|
||||
|
||||
return Response(
|
||||
content=csv_data,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
|
||||
)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/import",
|
||||
operation_id="import_style_presets",
|
||||
)
|
||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||
try:
|
||||
style_presets = await parse_presets_from_file(file)
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
except InvalidPresetImportDataError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except UnsupportedFileTypeError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
@@ -30,6 +30,7 @@ from invokeai.app.api.routers import (
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
@@ -55,11 +56,13 @@ mimetypes.add_type("text/css", ".css")
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
@@ -106,6 +109,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
@@ -184,8 +188,6 @@ def invoke_api() -> None:
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
# Start our own event loop for eventing usage
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import semver
|
||||
@@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
|
||||
version: str = Field(
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
||||
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
|
||||
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
if uiconfig.title is not None:
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig.tags is not None:
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig.category is not None:
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig.node_pack is not None:
|
||||
schema["node_pack"] = uiconfig.node_pack
|
||||
schema["classification"] = uiconfig.classification
|
||||
schema["version"] = uiconfig.version
|
||||
if title := model_class.UIConfig.title:
|
||||
schema["title"] = title
|
||||
if tags := model_class.UIConfig.tags:
|
||||
schema["tags"] = tags
|
||||
if category := model_class.UIConfig.category:
|
||||
schema["category"] = category
|
||||
if node_pack := model_class.UIConfig.node_pack:
|
||||
schema["node_pack"] = node_pack
|
||||
schema["classification"] = model_class.UIConfig.classification
|
||||
schema["version"] = model_class.UIConfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
@@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
UIConfig: ClassVar[UIConfigBase]
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
@@ -441,30 +438,25 @@ def invocation(
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconfig_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
|
||||
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
|
||||
cls.UIConfig.title = title
|
||||
cls.UIConfig.tags = tags
|
||||
cls.UIConfig.category = category
|
||||
cls.UIConfig.classification = classification
|
||||
|
||||
# Grab the node pack's name from the module name, if it's a custom node
|
||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
||||
if is_custom_node:
|
||||
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
|
||||
else:
|
||||
cls.UIConfig.node_pack = None
|
||||
uiconfig: dict[str, Any] = {}
|
||||
uiconfig["title"] = title
|
||||
uiconfig["tags"] = tags
|
||||
uiconfig["category"] = category
|
||||
uiconfig["classification"] = classification
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||
|
||||
if version is not None:
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
uiconfig["version"] = version
|
||||
else:
|
||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
cls.UIConfig.version = "1.0.0"
|
||||
uiconfig["version"] = "1.0.0"
|
||||
|
||||
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||
|
||||
if use_cache is not None:
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
model_state_dict=state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||
|
||||
@@ -21,6 +21,8 @@ from controlnet_aux import (
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
@@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.2",
|
||||
version="1.1.3",
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small", description="The size of the depth model to use"
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
@@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
blur_tensor[blur_tensor < 0] = 0.0
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
|
||||
@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@@ -60,8 +60,13 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@@ -498,6 +503,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_t2i_adapter_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
if t2i_adapters is None:
|
||||
return
|
||||
|
||||
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
||||
if isinstance(t2i_adapters, T2IAdapterField):
|
||||
t2i_adapters = [t2i_adapters]
|
||||
|
||||
for t2i_adapter_field in t2i_adapters:
|
||||
ext_manager.add_extension(
|
||||
T2IAdapterExt(
|
||||
node_context=context,
|
||||
model_id=t2i_adapter_field.t2i_adapter_model,
|
||||
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
||||
weight=t2i_adapter_field.weight,
|
||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -707,7 +739,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
return mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
@@ -765,10 +797,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
@@ -801,21 +829,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
@@ -833,6 +846,50 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if self.unet.freeu_config:
|
||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||
|
||||
### lora
|
||||
if self.unet.loras:
|
||||
for lora_field in self.unet.loras:
|
||||
ext_manager.add_extension(
|
||||
LoRAExt(
|
||||
node_context=context,
|
||||
model_id=lora_field.lora,
|
||||
weight=lora_field.weight,
|
||||
)
|
||||
)
|
||||
### seamless
|
||||
if self.unet.seamless_axes:
|
||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||
|
||||
### inpaint
|
||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||
if unet_config.variant == ModelVariantType.Inpaint:
|
||||
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||
elif mask is not None:
|
||||
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||
|
||||
# Initialize context for modular denoise
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
# later should be smth like:
|
||||
@@ -840,6 +897,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||
# ext_manager.add_extension(ext)
|
||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
@@ -871,6 +929,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
||||
# We invert the mask here for compatibility with the old backend implementation.
|
||||
if mask is not None:
|
||||
mask = 1 - mask
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
@@ -913,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
@@ -40,14 +40,18 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@@ -125,13 +129,17 @@ class FieldDescriptions:
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
@@ -231,6 +239,12 @@ class ColorField(BaseModel):
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
@@ -242,6 +256,31 @@ class ConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxField(BaseModel):
|
||||
"""A bounding box primitive value."""
|
||||
|
||||
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
|
||||
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
|
||||
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
|
||||
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
|
||||
|
||||
score: Optional[float] = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
|
||||
"when the bounding box was produced by a detector and has an associated confidence score.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_coords(self):
|
||||
if self.x_min > self.x_max:
|
||||
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
|
||||
if self.y_min > self.y_max:
|
||||
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
|
||||
return self
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
|
||||
92
invokeai/app/invocations/flux_text_encoder.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a flux image."""
|
||||
|
||||
clip: CLIPField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField = InputField(
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info as clip_text_encoder,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return pooled_prompt_embeds
|
||||
169
invokeai/app/invocations/flux_text_to_image.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
title="FLUX Text to Image",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Text-to-image generation using a FLUX model."""
|
||||
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
image = self._run_vae_decoding(context, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Prepare input noise.
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
x, img_ids = prepare_latent_img_patches(x)
|
||||
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=x.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
def step_callback() -> None:
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: Make this look like the image before re-enabling
|
||||
# latent_image = unpack(img.float(), self.height, self.width)
|
||||
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
|
||||
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
|
||||
|
||||
# # Create a new tensor of the required shape [255, 255, 3]
|
||||
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
|
||||
|
||||
# # Convert to a NumPy array and then to a PIL Image
|
||||
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
|
||||
|
||||
# (width, height) = image.size
|
||||
# width *= 8
|
||||
# height *= 8
|
||||
|
||||
# dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
# # TODO: move this whole function to invocation context to properly reference these variables
|
||||
# context._services.events.emit_invocation_denoise_progress(
|
||||
# context._data.queue_item,
|
||||
# context._data.invocation,
|
||||
# state,
|
||||
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
# )
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=step_callback,
|
||||
guidance=self.guidance,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _run_vae_decoding(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
latents: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
return img_pil
|
||||
100
invokeai/app/invocations/grounding_dino.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
|
||||
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
|
||||
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
|
||||
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
|
||||
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"grounding_dino",
|
||||
title="Grounding DINO (Text Prompt Object Detection)",
|
||||
tags=["prompt", "object detection"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GroundingDinoInvocation(BaseInvocation):
|
||||
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
|
||||
|
||||
# Reference:
|
||||
# - https://arxiv.org/pdf/2303.05499
|
||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
|
||||
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
|
||||
prompt: str = InputField(description="The prompt describing the object to segment.")
|
||||
image: ImageField = InputField(description="The image to segment.")
|
||||
detection_threshold: float = InputField(
|
||||
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
default=0.3,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
|
||||
# The model expects a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
detections = self._detect(
|
||||
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
|
||||
)
|
||||
|
||||
# Convert detections to BoundingBoxCollectionOutput.
|
||||
bounding_boxes: list[BoundingBoxField] = []
|
||||
for detection in detections:
|
||||
bounding_boxes.append(
|
||||
BoundingBoxField(
|
||||
x_min=detection.box.xmin,
|
||||
x_max=detection.box.xmax,
|
||||
y_min=detection.box.ymin,
|
||||
y_max=detection.box.ymax,
|
||||
score=detection.score,
|
||||
)
|
||||
)
|
||||
return BoundingBoxCollectionOutput(collection=bounding_boxes)
|
||||
|
||||
@staticmethod
|
||||
def _load_grounding_dino(model_path: Path):
|
||||
grounding_dino_pipeline = pipeline(
|
||||
model=str(model_path),
|
||||
task="zero-shot-object-detection",
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||
|
||||
def _detect(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
labels: list[str],
|
||||
threshold: float = 0.3,
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
||||
# actually makes a difference.
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
@@ -6,13 +6,19 @@ import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
ColorField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
@@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation_output("canvas_v2_mask_and_crop_output")
|
||||
class CanvasV2MaskAndCropOutput(ImageOutput):
|
||||
offset_x: int = OutputField(description="The x offset of the image, after cropping")
|
||||
offset_y: int = OutputField(description="The y offset of the image, after cropping")
|
||||
|
||||
|
||||
@invocation(
|
||||
"canvas_v2_mask_and_crop",
|
||||
title="Canvas V2 Mask and Crop",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Handles Canvas V2 image output masking and cropping"""
|
||||
|
||||
source_image: ImageField | None = InputField(
|
||||
default=None,
|
||||
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
|
||||
)
|
||||
generated_image: ImageField = InputField(description="The image to apply the mask to")
|
||||
mask: ImageField = InputField(description="The mask to apply")
|
||||
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
|
||||
|
||||
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
|
||||
mask_array = numpy.array(mask)
|
||||
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
|
||||
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
|
||||
dilated_mask = Image.fromarray(dilated_mask_array)
|
||||
if self.mask_blur > 0:
|
||||
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
return ImageOps.invert(mask.convert("L"))
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
|
||||
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
|
||||
|
||||
if self.source_image:
|
||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||
source_image = context.images.get_pil(self.source_image.image_name)
|
||||
source_image.paste(generated_image, (0, 0), mask)
|
||||
image_dto = context.images.save(image=source_image)
|
||||
else:
|
||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||
generated_image.putalpha(mask)
|
||||
image_dto = context.images.save(image=generated_image)
|
||||
|
||||
# bbox = image.getbbox()
|
||||
# image = image.crop(bbox)
|
||||
|
||||
return CanvasV2MaskAndCropOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
offset_x=0,
|
||||
offset_y=0,
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
height=mask.shape[1],
|
||||
width=mask.shape[2],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"tensor_mask_to_image",
|
||||
title="Tensor Mask to Image",
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Convert a mask tensor to an image."""
|
||||
|
||||
mask: TensorField = InputField(description="The mask tensor to convert.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
# Ensure that the mask is binary.
|
||||
if mask.dtype != torch.bool:
|
||||
mask = mask > 0.5
|
||||
mask_np = (mask.float() * 255).byte().cpu().numpy()
|
||||
|
||||
mask_pil = Image.fromarray(mask_np, mode="L")
|
||||
image_dto = context.images.save(image=mask_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@@ -60,6 +67,15 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
@@ -122,6 +138,78 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Flux base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
max_seq_len: Literal[256, 512] = OutputField(
|
||||
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
|
||||
title="Max Seq Length",
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
|
||||
@@ -7,10 +7,12 @@ import torch
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoundingBoxField,
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
@@ -413,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("flux_conditioning_output")
|
||||
class FluxConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
|
||||
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
@@ -469,3 +482,42 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region BoundingBox
|
||||
|
||||
|
||||
@invocation_output("bounding_box_output")
|
||||
class BoundingBoxOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single bounding box"""
|
||||
|
||||
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
|
||||
|
||||
|
||||
@invocation_output("bounding_box_collection_output")
|
||||
class BoundingBoxCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of bounding boxes"""
|
||||
|
||||
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
|
||||
|
||||
|
||||
@invocation(
|
||||
"bounding_box",
|
||||
title="Bounding Box",
|
||||
tags=["primitives", "segmentation", "collection", "bounding box"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BoundingBoxInvocation(BaseInvocation):
|
||||
"""Create a bounding box manually by supplying box coordinates"""
|
||||
|
||||
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
|
||||
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
|
||||
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
|
||||
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
|
||||
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
|
||||
return BoundingBoxOutput(bounding_box=bounding_box)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -1,76 +1,161 @@
|
||||
from typing import Dict, cast
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.grounding_segment_anything.gsa import GroundingSegmentAnythingDetector
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
|
||||
GROUNDING_SEGMENT_ANYTHING_MODELS = {
|
||||
"groundingdino_swint_ogc": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
||||
"segment_anything_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
||||
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
|
||||
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
|
||||
"segment-anything-base": "facebook/sam-vit-base",
|
||||
"segment-anything-large": "facebook/sam-vit-large",
|
||||
"segment-anything-huge": "facebook/sam-vit-huge",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything",
|
||||
title="Segment Anything",
|
||||
tags=["grounding_dino", "segment", "anything"],
|
||||
category="image",
|
||||
tags=["prompt", "segmentation"],
|
||||
category="segmentation",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SegmentAnythingInvocation(BaseInvocation):
|
||||
"""Automatically generate masks from an image using GroundingDINO & Segment Anything"""
|
||||
"""Runs a Segment Anything Model."""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
prompt: str = InputField(default="", description="Keywords to segment", title="Prompt")
|
||||
box_threshold: float = InputField(
|
||||
default=0.5, ge=0, le=1, description="Threshold of box detection", title="Box Threshold"
|
||||
# Reference:
|
||||
# - https://arxiv.org/pdf/2304.02643
|
||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
|
||||
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
|
||||
image: ImageField = InputField(description="The image to segment.")
|
||||
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
|
||||
apply_polygon_refinement: bool = InputField(
|
||||
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
|
||||
default=True,
|
||||
)
|
||||
text_threshold: float = InputField(
|
||||
default=0.5, ge=0, le=1, description="Threshold of text detection", title="Text Threshold"
|
||||
)
|
||||
nms_threshold: float = InputField(
|
||||
default=0.8, ge=0, le=1, description="Threshold of nms detection", title="NMS Threshold"
|
||||
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
|
||||
description="The filtering to apply to the detected masks before merging them into a final output.",
|
||||
default="all",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
input_image = context.images.get_pil(self.image.image_name)
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
# The models expect a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
grounding_dino_model = context.models.load_remote_model(
|
||||
GROUNDING_SEGMENT_ANYTHING_MODELS["groundingdino_swint_ogc"]
|
||||
)
|
||||
segment_anything_model = context.models.load_remote_model(
|
||||
GROUNDING_SEGMENT_ANYTHING_MODELS["segment_anything_vit_h"]
|
||||
if len(self.bounding_boxes) == 0:
|
||||
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
|
||||
else:
|
||||
masks = self._segment(context=context, image=image_pil)
|
||||
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
|
||||
|
||||
# masks contains bool values, so we merge them via max-reduce.
|
||||
combined_mask, _ = torch.stack(masks).max(dim=0)
|
||||
|
||||
mask_tensor_name = context.tensors.save(combined_mask)
|
||||
height, width = combined_mask.shape
|
||||
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
|
||||
|
||||
@staticmethod
|
||||
def _load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(sam_model, SamModel)
|
||||
|
||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
|
||||
|
||||
def _segment(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
) -> list[torch.Tensor]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
# Convert the bounding boxes to the SAM input format.
|
||||
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
|
||||
|
||||
with (
|
||||
grounding_dino_model.model_on_device() as (_, grounding_dino_state_dict),
|
||||
segment_anything_model.model_on_device() as (_, segment_anything_state_dict),
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
|
||||
) as sam_pipeline,
|
||||
):
|
||||
if not grounding_dino_state_dict or not segment_anything_state_dict:
|
||||
raise RuntimeError("Unable to load segmentation models")
|
||||
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
||||
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
|
||||
|
||||
grounding_dino = GroundingSegmentAnythingDetector.build_grounding_dino(
|
||||
cast(Dict[str, torch.Tensor], grounding_dino_state_dict), TorchDevice.choose_torch_device()
|
||||
)
|
||||
segment_anything = GroundingSegmentAnythingDetector.build_segment_anything(
|
||||
cast(Dict[str, torch.Tensor], segment_anything_state_dict), TorchDevice.choose_torch_device()
|
||||
)
|
||||
detector = GroundingSegmentAnythingDetector(grounding_dino, segment_anything)
|
||||
masks = self._process_masks(masks)
|
||||
if self.apply_polygon_refinement:
|
||||
masks = self._apply_polygon_refinement(masks)
|
||||
|
||||
mask = detector.predict(
|
||||
input_image, self.prompt, self.box_threshold, self.text_threshold, self.nms_threshold
|
||||
)
|
||||
image_dto = context.images.save(mask)
|
||||
return masks
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
width=input_image.width,
|
||||
height=input_image.height,
|
||||
)
|
||||
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""Convert the tensor output from the Segment Anything model from a tensor of shape
|
||||
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
|
||||
"""
|
||||
assert masks.dtype == torch.bool
|
||||
# [num_masks, channels, height, width] -> [num_masks, height, width]
|
||||
masks, _ = masks.max(dim=1)
|
||||
# Split the first dimension into a list of masks.
|
||||
return list(masks.cpu().unbind(dim=0))
|
||||
|
||||
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""Apply polygon refinement to the masks.
|
||||
|
||||
Convert each mask to a polygon, then back to a mask. This has the following effect:
|
||||
- Smooth the edges of the mask slightly.
|
||||
- Ensure that each mask consists of a single closed polygon
|
||||
- Removes small mask pieces.
|
||||
- Removes holes from the mask.
|
||||
"""
|
||||
# Convert tensor masks to np masks.
|
||||
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
|
||||
|
||||
# Apply polygon refinement.
|
||||
for idx, mask in enumerate(np_masks):
|
||||
shape = mask.shape
|
||||
assert len(shape) == 2 # Assert length to satisfy type checker.
|
||||
polygon = mask_to_polygon(mask)
|
||||
mask = polygon_to_mask(polygon, shape)
|
||||
np_masks[idx] = mask
|
||||
|
||||
# Convert np masks back to tensor masks.
|
||||
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
|
||||
|
||||
return masks
|
||||
|
||||
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
|
||||
"""Filter the detected masks based on the specified mask filter."""
|
||||
assert len(masks) == len(bounding_boxes)
|
||||
|
||||
if self.mask_filter == "all":
|
||||
return masks
|
||||
elif self.mask_filter == "largest":
|
||||
# Find the largest mask.
|
||||
return [max(masks, key=lambda x: float(x.sum()))]
|
||||
elif self.mask_filter == "highest_box_score":
|
||||
# Find the index of the bounding box with the highest score.
|
||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
||||
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
|
||||
# reasonable fallback since the expected score range is [0.0, 1.0].
|
||||
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
|
||||
return [masks[max_score_idx]]
|
||||
else:
|
||||
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.vto_workflow.extract_channel import ImageChannel, extract_channel
|
||||
from invokeai.backend.vto_workflow.overlay_pattern import multiply_images
|
||||
from invokeai.backend.vto_workflow.seamless_mapping import map_seamless_tiles
|
||||
|
||||
|
||||
@invocation("vto", title="Virtual Try-On", tags=["vto"], category="vto", version="1.1.0")
|
||||
class VTOInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Virtual try-on."""
|
||||
|
||||
original_image: ImageField = InputField(description="The input image")
|
||||
clothing_mask: ImageField = InputField(description="Clothing mask.")
|
||||
pattern_image: ImageField = InputField(description="Pattern image.")
|
||||
pattern_vertical_repeats: float = InputField(
|
||||
description="Number of vertical repeats for the pattern.", gt=0.01, default=1.0
|
||||
)
|
||||
|
||||
shading_max: float = InputField(
|
||||
description="The lightness of the light spots on the clothing. Default is 1.0. Typically in the range [0.7, 1.2]. Must be > shading_min",
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
)
|
||||
shading_min: float = InputField(
|
||||
description="The lightness of the dark spots on the clothing. Default id 0.5. Typically in the range [0.2, 0.7]",
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
mask_dilation: int = InputField(
|
||||
description="The number of pixels to dilate the mask by. Default is 1.",
|
||||
default=1,
|
||||
ge=0,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# TODO(ryand): Avoid all the unnecessary flip-flopping between PIL and numpy.
|
||||
original_image = context.images.get_pil(self.original_image.image_name)
|
||||
clothing_mask = context.images.get_pil(self.clothing_mask.image_name)
|
||||
pattern_image = context.images.get_pil(self.pattern_image.image_name)
|
||||
|
||||
shadows = extract_channel(np.array(original_image), ImageChannel.LAB_L)
|
||||
|
||||
# Clip the shadows to the 0.05 and 0.95 percentiles to eliminate outliers.
|
||||
shadows = np.clip(shadows, np.percentile(shadows, 5), np.percentile(shadows, 95))
|
||||
|
||||
# Normalize the shadows to the range [shading_min, shading_max].
|
||||
assert self.shading_min < self.shading_max
|
||||
shadows = shadows.astype(np.float32)
|
||||
shadows = (shadows - shadows.min()) / (shadows.max() - shadows.min())
|
||||
shadows = self.shading_min + (self.shading_max - self.shading_min) * shadows
|
||||
shadows = np.clip(shadows, 0.0, 1.0)
|
||||
shadows = (shadows * 255).astype(np.uint8)
|
||||
|
||||
expanded_pattern = map_seamless_tiles(
|
||||
seamless_tile=pattern_image,
|
||||
target_hw=(original_image.height, original_image.width),
|
||||
num_repeats_h=self.pattern_vertical_repeats,
|
||||
)
|
||||
|
||||
pattern_with_shadows = multiply_images(expanded_pattern, Image.fromarray(shadows))
|
||||
|
||||
# Dilate the mask.
|
||||
clothing_mask_np = np.array(clothing_mask)
|
||||
if self.mask_dilation > 0:
|
||||
clothing_mask_np = cv2.dilate(clothing_mask_np, np.ones((3, 3), np.uint8), iterations=self.mask_dilation)
|
||||
|
||||
# Merge the pattern with the model image.
|
||||
pattern_with_shadows_np = np.array(pattern_with_shadows)
|
||||
original_image_np = np.array(original_image)
|
||||
merged_image = np.where(clothing_mask_np[:, :, None], pattern_with_shadows_np, original_image_np)
|
||||
merged_image = Image.fromarray(merged_image)
|
||||
|
||||
image_dto = context.images.save(image=merged_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
custom_nodes_dir: Path to directory for custom nodes.
|
||||
style_presets_dir: Path to directory for style presets.
|
||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
@@ -153,6 +154,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
|
||||
|
||||
# LOGGING
|
||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||
@@ -300,6 +302,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def style_presets_path(self) -> Path:
|
||||
"""Path to the style presets directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.style_presets_dir)
|
||||
|
||||
@property
|
||||
def convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
|
||||
@@ -88,6 +88,8 @@ class QueueItemEventBase(QueueEventBase):
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the queue item")
|
||||
destination: str | None = Field(default=None, description="The destination of the queue item")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
@@ -95,8 +97,6 @@ class InvocationEventBase(QueueItemEventBase):
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
@@ -114,6 +114,8 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -147,6 +149,8 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -184,6 +188,8 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -216,6 +222,8 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -253,6 +261,8 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
@@ -279,12 +289,14 @@ class BatchEnqueuedEvent(QueueEventBase):
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
origin=enqueue_result.batch.origin,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
|
||||
@@ -1,46 +1,44 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._queue = asyncio.Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._loop = loop
|
||||
|
||||
# We need to store a reference to the task so it doesn't get GC'd
|
||||
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
|
||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.remove)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
event = await self._queue.get()
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
# Leave the payloads as live pydantic models
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.image_files.image_files_common import (
|
||||
@@ -20,18 +19,12 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: Path
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[Path, PILImageType]
|
||||
__max_cache_size: int
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__cache: dict[Path, PILImageType] = {}
|
||||
self.__cache_ids = Queue[Path]()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||
# Validate required output folders at launch
|
||||
self.__validate_storage_folders()
|
||||
@@ -103,7 +96,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
if image_path.exists():
|
||||
send2trash(image_path)
|
||||
image_path.unlink()
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
@@ -111,7 +104,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||
|
||||
if thumbnail_path.exists():
|
||||
send2trash(thumbnail_path)
|
||||
thumbnail_path.unlink()
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
@@ -61,6 +63,8 @@ class InvocationServices:
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -85,3 +89,5 @@ class InvocationServices:
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
|
||||
@@ -2,7 +2,6 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
|
||||
@@ -70,7 +69,7 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
send2trash(path)
|
||||
path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
@@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||
if subfolder:
|
||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||
path_to_add = Path(f"{top}_{subfolder}")
|
||||
path_to_remove = top / subfolder # sdxl-turbo/vae/
|
||||
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
|
||||
path_to_add = Path(f"{top}_{subfolder_rename}")
|
||||
else:
|
||||
path_to_remove = Path(".")
|
||||
path_to_add = Path(".")
|
||||
|
||||
@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
format: Optional[str] = Field(description="format of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
||||
@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByOriginResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
||||
"""Cancels all queue items with the given batch origin"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
|
||||
@@ -77,6 +77,14 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
@@ -195,6 +203,14 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
@@ -294,6 +310,8 @@ class SessionQueueStatus(BaseModel):
|
||||
class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
origin: str | None = Field(..., description="The origin of the batch")
|
||||
destination: str | None = Field(..., description="The destination of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
@@ -328,6 +346,12 @@ class CancelByBatchIDsResult(BaseModel):
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByOriginResult(BaseModel):
|
||||
"""Result of canceling by list of batch ids"""
|
||||
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
@@ -433,6 +457,8 @@ class SessionQueueValueToInsert(NamedTuple):
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
destination: str | None
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
@@ -453,6 +479,8 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
||||
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
||||
priority, # priority
|
||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||
batch.origin, # origin
|
||||
batch.destination, # destination
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByOriginResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND origin == ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = (queue_id, origin)
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.origin == origin:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByOriginResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
@@ -541,7 +578,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -621,7 +660,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -633,6 +672,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -641,6 +682,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
origin=origin,
|
||||
destination=destination,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
|
||||
@@ -16,6 +16,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -49,6 +51,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.register_migration(build_migration_15())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration14Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_style_presets(cursor)
|
||||
|
||||
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create the table used to store style presets."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS style_presets (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
preset_data TEXT NOT NULL,
|
||||
type TEXT NOT NULL DEFAULT "user",
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS style_presets
|
||||
AFTER UPDATE
|
||||
ON style_presets FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_14() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 13 to 14..
|
||||
|
||||
This migration does the following:
|
||||
- Create the table used to store style presets.
|
||||
"""
|
||||
migration_14 = Migration(
|
||||
from_version=13,
|
||||
to_version=14,
|
||||
callback=Migration14Callback(),
|
||||
)
|
||||
|
||||
return migration_14
|
||||
@@ -0,0 +1,34 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration15Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._add_origin_col(cursor)
|
||||
|
||||
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
|
||||
|
||||
|
||||
def build_migration_15() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 14 to 15.
|
||||
|
||||
This migration does the following:
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
migration_15 = Migration(
|
||||
from_version=14,
|
||||
to_version=15,
|
||||
callback=Migration15Callback(),
|
||||
)
|
||||
|
||||
return migration_15
|
||||
|
After Width: | Height: | Size: 98 KiB |
|
After Width: | Height: | Size: 138 KiB |
|
After Width: | Height: | Size: 122 KiB |
|
After Width: | Height: | Size: 123 KiB |
|
After Width: | Height: | Size: 160 KiB |
|
After Width: | Height: | Size: 146 KiB |
|
After Width: | Height: | Size: 119 KiB |
|
After Width: | Height: | Size: 117 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 46 KiB |
|
After Width: | Height: | Size: 79 KiB |
|
After Width: | Height: | Size: 156 KiB |
|
After Width: | Height: | Size: 141 KiB |
|
After Width: | Height: | Size: 96 KiB |
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 88 KiB |
|
After Width: | Height: | Size: 107 KiB |
|
After Width: | Height: | Size: 132 KiB |
@@ -0,0 +1,33 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class StylePresetImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
"""Retrieves a style preset image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
"""Gets the internal path to a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
"""Gets the URL to fetch a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
"""Saves a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset image."""
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
class StylePresetImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not deleted"):
|
||||
super().__init__(message)
|
||||
@@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import (
|
||||
StylePresetImageFileDeleteException,
|
||||
StylePresetImageFileNotFoundException,
|
||||
StylePresetImageFileSaveException,
|
||||
)
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
|
||||
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, style_preset_images_folder: Path):
|
||||
self._style_preset_images_folder = style_preset_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileSaveException from e
|
||||
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
|
||||
if style_preset.type is PresetType.Default:
|
||||
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
|
||||
path = default_images_dir / (style_preset.name + ".png")
|
||||
else:
|
||||
path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
path = self.get_path(style_preset_id)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
|
||||
|
||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
||||
url += f"?{uuid_string()}"
|
||||
|
||||
return url
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise StylePresetImageFileNotFoundException
|
||||
|
||||
path.unlink()
|
||||
|
||||
except StylePresetImageFileNotFoundException as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)
|
||||
@@ -0,0 +1,146 @@
|
||||
[
|
||||
{
|
||||
"name": "Photography (General)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Studio Lighting)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Landscape)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Portrait)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Black and White)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
|
||||
"negative_prompt": "painting, digital art. sketch, colour+"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Architectural Visualization",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Fantasy)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Sci-Fi)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Character)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Painterly)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
|
||||
"negative_prompt": "photo. smooth. border. frame"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Environment Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
|
||||
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Interior Design (Visualization)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
|
||||
"negative_prompt": "photo, distorted. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Product Rendering",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
|
||||
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Sketch",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
|
||||
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Line Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
|
||||
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Anime",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
|
||||
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Illustration",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
|
||||
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Vehicles",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
|
||||
"negative_prompt": "sketch. digital art. greyscale. painting"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetRecordsStorageBase(ABC):
|
||||
"""Base class for style preset storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Get style preset by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
"""Creates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
"""Creates many style presets."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
"""Updates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -0,0 +1,139 @@
|
||||
import codecs
|
||||
import csv
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import UploadFile
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class StylePresetNotFoundError(Exception):
|
||||
"""Raised when a style preset is not found"""
|
||||
|
||||
|
||||
class PresetData(BaseModel, extra="forbid"):
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
|
||||
|
||||
PresetDataValidator = TypeAdapter(PresetData)
|
||||
|
||||
|
||||
class PresetType(str, Enum, metaclass=MetaEnum):
|
||||
User = "user"
|
||||
Default = "default"
|
||||
Project = "project"
|
||||
|
||||
|
||||
class StylePresetChanges(BaseModel, extra="forbid"):
|
||||
name: Optional[str] = Field(default=None, description="The style preset's new name.")
|
||||
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
|
||||
type: Optional[PresetType] = Field(description="The updated type of the style preset")
|
||||
|
||||
|
||||
class StylePresetWithoutId(BaseModel):
|
||||
name: str = Field(description="The name of the style preset.")
|
||||
preset_data: PresetData = Field(description="The preset data")
|
||||
type: PresetType = Field(description="The type of style preset")
|
||||
|
||||
|
||||
class StylePresetRecordDTO(StylePresetWithoutId):
|
||||
id: str = Field(description="The style preset ID.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
||||
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
|
||||
return StylePresetRecordDTOValidator.validate_python(data)
|
||||
|
||||
|
||||
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
||||
|
||||
|
||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
||||
image: Optional[str] = Field(description="The path for image")
|
||||
|
||||
|
||||
class StylePresetImportRow(BaseModel):
|
||||
name: str = Field(min_length=1, description="The name of the preset.")
|
||||
positive_prompt: str = Field(
|
||||
default="",
|
||||
description="The positive prompt for the preset.",
|
||||
validation_alias=AliasChoices("positive_prompt", "prompt"),
|
||||
)
|
||||
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
||||
|
||||
|
||||
StylePresetImportList = list[StylePresetImportRow]
|
||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(ValueError):
|
||||
"""Raised when an unsupported file type is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidPresetImportDataError(ValueError):
|
||||
"""Raised when invalid preset import data is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
|
||||
"""Parses style presets from a file. The file must be a CSV or JSON file.
|
||||
|
||||
If CSV, the file must have the following columns:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
If JSON, the file must be a list of objects with the following keys:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
Args:
|
||||
file (UploadFile): The file to parse.
|
||||
|
||||
Returns:
|
||||
list[StylePresetWithoutId]: The parsed style presets.
|
||||
|
||||
Raises:
|
||||
UnsupportedFileTypeError: If the file type is not supported.
|
||||
InvalidPresetImportDataError: If the data in the file is invalid.
|
||||
"""
|
||||
if file.content_type not in ["text/csv", "application/json"]:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if file.content_type == "text/csv":
|
||||
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
|
||||
data = list(csv_reader)
|
||||
else: # file.content_type == "application/json":
|
||||
json_data = await file.read()
|
||||
data = json.loads(json_data)
|
||||
|
||||
try:
|
||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
||||
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
|
||||
for imported in imported_presets:
|
||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
except pydantic.ValidationError as e:
|
||||
if file.content_type == "text/csv":
|
||||
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
else: # file.content_type == "application/json":
|
||||
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
raise InvalidPresetImportDataError(msg) from e
|
||||
finally:
|
||||
file.file.close()
|
||||
|
||||
return style_presets
|
||||
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._sync_default_style_presets()
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET name = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.name, style_preset_id),
|
||||
)
|
||||
|
||||
# Change the preset data for a style preset
|
||||
if changes.preset_data is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET preset_data = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
self._cursor.execute(main_query, (type,))
|
||||
else:
|
||||
self._cursor.execute(main_query)
|
||||
|
||||
rows = self._cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Next, parse and create the default style presets
|
||||
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
for preset in presets:
|
||||
style_preset = StylePresetWithoutId.model_validate(preset)
|
||||
self.create(style_preset)
|
||||
@@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
"""Gets the URL for a style preset image"""
|
||||
pass
|
||||
|
||||
@@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
{
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"fieldName": "num_steps"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 824.1970602278849,
|
||||
"y": 146.98251001061735
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 822.9899179655476,
|
||||
"y": 360.9657214885052
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "flux_text_to_image",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1216.3900791301849,
|
||||
"y": 5.500841807102248
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -81,7 +81,7 @@ def get_openapi_func(
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": invocation_output_map_properties,
|
||||
"properties": dict(sorted(invocation_output_map_properties.items())),
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
|
||||
32
invokeai/backend/flux/math.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
117
invokeai/backend/flux/model.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
310
invokeai/backend/flux/modules/autoencoder.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: list[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, h_: Tensor) -> Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
else:
|
||||
return mean
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, params: AutoEncoderParams):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
out_ch=params.out_ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.reg = DiagonalGaussian()
|
||||
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
z = self.reg(self.encoder(x))
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.decode(self.encode(x))
|
||||
33
invokeai/backend/flux/modules/conditioner.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from torch import Tensor, nn
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.is_clip = is_clip
|
||||
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
||||
self.tokenizer = tokenizer
|
||||
self.hf_module = encoder
|
||||
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
||||
|
||||
def forward(self, text: list[str]) -> Tensor:
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
outputs = self.hf_module(
|
||||
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
||||
attention_mask=None,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return outputs[self.output_key]
|
||||
253
invokeai/backend/flux/modules/layers.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.math import attention, rope
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
167
invokeai/backend/flux/sampling.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
step_callback()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
71
invokeai/backend/flux/util.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Literal
|
||||
|
||||
from invokeai.backend.flux.model import FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
ckpt_path: str | None
|
||||
ae_path: str | None
|
||||
repo_id: str | None
|
||||
repo_flow: str | None
|
||||
repo_ae: str | None
|
||||
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-schnell": 256,
|
||||
}
|
||||
|
||||
|
||||
ae_params = {
|
||||
"flux": AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
params = {
|
||||
"flux-dev": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
"flux-schnell": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
width=518,
|
||||
height=518,
|
||||
resize_target=False,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=14,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
logger.warn("DepthAnything model was not loaded. Returning original image")
|
||||
return image
|
||||
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
np_image = np_image[:, :, ::-1] / 255.0
|
||||
|
||||
image_height, image_width = np_image.shape[:2]
|
||||
np_image = transform({"image": np_image})["image"]
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
depth = self.model(tensor_image)
|
||||
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
||||
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||||
|
||||
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||
depth_map = Image.fromarray(depth_map)
|
||||
|
||||
new_height = int(image_height * (resolution / image_width))
|
||||
depth_map = depth_map.resize((resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
@@ -0,0 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class DepthAnythingPipeline(RawModel):
|
||||
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
||||
for Invoke's Model Management System"""
|
||||
|
||||
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
|
||||
self._pipeline = pipeline
|
||||
|
||||
def generate_depth(self, image: Image.Image) -> Image.Image:
|
||||
depth_map = self._pipeline(image)["depth"]
|
||||
assert isinstance(depth_map, Image.Image)
|
||||
return depth_map
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
@@ -1,145 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape
|
||||
|
||||
if expand:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape * 2
|
||||
out_shape3 = out_shape * 4
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape * 8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
if len(in_shape) >= 4:
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
"""Residual convolution module."""
|
||||
|
||||
def __init__(self, features, activation, bn):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.bn = bn
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
if self.bn:
|
||||
self.bn1 = nn.BatchNorm2d(features)
|
||||
self.bn2 = nn.BatchNorm2d(features)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
if self.bn:
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
if self.bn:
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.groups > 1:
|
||||
out = self.conv_merge(out)
|
||||
|
||||
return self.skip_add.add(out, x)
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
"""Feature fusion block."""
|
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock, self).__init__()
|
||||
|
||||
self.deconv = deconv
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.expand = expand
|
||||
out_features = features
|
||||
if self.expand:
|
||||
out_features = features // 2
|
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
||||
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
self.size = size
|
||||
|
||||
def forward(self, *xs, size=None):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
res = self.resConfUnit1(xs[1])
|
||||
output = self.skip_add.add(output, res)
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
if (size is None) and (self.size is None):
|
||||
modifier = {"scale_factor": 2}
|
||||
elif size is None:
|
||||
modifier = {"size": self.size}
|
||||
else:
|
||||
modifier = {"size": size}
|
||||
|
||||
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
||||
|
||||
output = self.out_conv(output)
|
||||
|
||||
return output
|
||||
@@ -1,183 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
|
||||
|
||||
torchhub_path = Path(__file__).parent.parent / "torchhub"
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size=None):
|
||||
return FeatureFusionBlock(
|
||||
features,
|
||||
nn.ReLU(False),
|
||||
deconv=False,
|
||||
bn=use_bn,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
|
||||
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
|
||||
super(DPTHead, self).__init__()
|
||||
|
||||
self.nclass = nclass
|
||||
self.use_clstoken = use_clstoken
|
||||
|
||||
self.projects = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
for out_channel in out_channels
|
||||
]
|
||||
)
|
||||
|
||||
self.resize_layers = nn.ModuleList(
|
||||
[
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
||||
),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if use_clstoken:
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
|
||||
|
||||
self.scratch = _make_scratch(
|
||||
out_channels,
|
||||
features,
|
||||
groups=1,
|
||||
expand=False,
|
||||
)
|
||||
|
||||
self.scratch.stem_transpose = None
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
if nclass > 1:
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.scratch.output_conv1 = nn.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, out_features, patch_h, patch_w):
|
||||
out = []
|
||||
for i, x in enumerate(out_features):
|
||||
if self.use_clstoken:
|
||||
x, cls_token = x[0], x[1]
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
else:
|
||||
x = x[0]
|
||||
|
||||
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
||||
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
|
||||
out.append(x)
|
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = out
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv1(path_1)
|
||||
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
||||
out = self.scratch.output_conv2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DPT_DINOv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
features,
|
||||
out_channels,
|
||||
encoder="vitl",
|
||||
use_bn=False,
|
||||
use_clstoken=False,
|
||||
):
|
||||
super(DPT_DINOv2, self).__init__()
|
||||
|
||||
assert encoder in ["vits", "vitb", "vitl"]
|
||||
|
||||
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
||||
# if use_local:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# torchhub_path / "facebookresearch_dinov2_main",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# source="local",
|
||||
# pretrained=False,
|
||||
# )
|
||||
# else:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# "facebookresearch/dinov2",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# )
|
||||
|
||||
self.pretrained = torch.hub.load(
|
||||
"facebookresearch/dinov2",
|
||||
"dinov2_{:}14".format(encoder),
|
||||
)
|
||||
|
||||
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
||||
|
||||
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
|
||||
|
||||
patch_h, patch_w = h // 14, w // 14
|
||||
|
||||
depth = self.depth_head(features, patch_h, patch_w)
|
||||
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
|
||||
depth = F.relu(depth)
|
||||
|
||||
return depth.squeeze(1)
|
||||
@@ -1,227 +0,0 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||
|
||||
Args:
|
||||
sample (dict): sample
|
||||
size (tuple): image size
|
||||
|
||||
Returns:
|
||||
tuple: new size
|
||||
"""
|
||||
shape = list(sample["disparity"].shape)
|
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||
return sample
|
||||
|
||||
scale = [0, 0]
|
||||
scale[0] = size[0] / shape[0]
|
||||
scale[1] = size[1] / shape[1]
|
||||
|
||||
scale = max(scale)
|
||||
|
||||
shape[0] = math.ceil(scale * shape[0])
|
||||
shape[1] = math.ceil(scale * shape[1])
|
||||
|
||||
# resize
|
||||
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
|
||||
|
||||
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
tuple(shape[::-1]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize sample to given size (width, height)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
resize_target=True,
|
||||
keep_aspect_ratio=False,
|
||||
ensure_multiple_of=1,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_AREA,
|
||||
):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
width (int): desired output width
|
||||
height (int): desired output height
|
||||
resize_target (bool, optional):
|
||||
True: Resize the full sample (image, mask, target).
|
||||
False: Resize image only.
|
||||
Defaults to True.
|
||||
keep_aspect_ratio (bool, optional):
|
||||
True: Keep the aspect ratio of the input sample.
|
||||
Output sample might not have the given width and height, and
|
||||
resize behaviour depends on the parameter 'resize_method'.
|
||||
Defaults to False.
|
||||
ensure_multiple_of (int, optional):
|
||||
Output width and height is constrained to be multiple of this parameter.
|
||||
Defaults to 1.
|
||||
resize_method (str, optional):
|
||||
"lower_bound": Output will be at least as large as the given size.
|
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
|
||||
than given size.)
|
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||
Defaults to "lower_bound".
|
||||
"""
|
||||
self.__width = width
|
||||
self.__height = height
|
||||
|
||||
self.__resize_target = resize_target
|
||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||
self.__multiple_of = ensure_multiple_of
|
||||
self.__resize_method = resize_method
|
||||
self.__image_interpolation_method = image_interpolation_method
|
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if max_val is not None and y > max_val:
|
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if y < min_val:
|
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
return y
|
||||
|
||||
def get_size(self, width, height):
|
||||
# determine new height and width
|
||||
scale_height = self.__height / height
|
||||
scale_width = self.__width / width
|
||||
|
||||
if self.__keep_aspect_ratio:
|
||||
if self.__resize_method == "lower_bound":
|
||||
# scale such that output size is lower bound
|
||||
if scale_width > scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "upper_bound":
|
||||
# scale such that output size is upper bound
|
||||
if scale_width < scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "minimal":
|
||||
# scale as least as possbile
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
if self.__resize_method == "lower_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
||||
elif self.__resize_method == "upper_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
||||
elif self.__resize_method == "minimal":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
return (new_width, new_height)
|
||||
|
||||
def __call__(self, sample):
|
||||
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
||||
|
||||
# resize sample
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"],
|
||||
(width, height),
|
||||
interpolation=self.__image_interpolation_method,
|
||||
)
|
||||
|
||||
if self.__resize_target:
|
||||
if "disparity" in sample:
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"],
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
if "depth" in sample:
|
||||
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
# sample["semseg_mask"] = cv2.resize(
|
||||
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||
# )
|
||||
sample["semseg_mask"] = F.interpolate(
|
||||
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
|
||||
).numpy()[0, 0]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
# sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
# print(sample['image'].shape, sample['depth'].shape)
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""Normlize image by given mean and std."""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.__mean = mean
|
||||
self.__std = std
|
||||
|
||||
def __call__(self, sample):
|
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
|
||||
"""Prepare sample for usage as network input."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, sample):
|
||||
image = np.transpose(sample["image"], (2, 0, 1))
|
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].astype(np.float32)
|
||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||
|
||||
if "depth" in sample:
|
||||
depth = sample["depth"].astype(np.float32)
|
||||
sample["depth"] = np.ascontiguousarray(depth)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
||||
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
||||
|
||||
return sample
|
||||
@@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
|
||||
"""Bounding box helper class."""
|
||||
|
||||
xmin: int
|
||||
ymin: int
|
||||
xmax: int
|
||||
ymax: int
|
||||
|
||||
|
||||
class DetectionResult(BaseModel):
|
||||
"""Detection result from Grounding DINO."""
|
||||
|
||||
score: float
|
||||
label: str
|
||||
box: BoundingBox
|
||||
model_config = ConfigDict(
|
||||
# Allow arbitrary types for mask, since it will be a numpy array.
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class GroundingDinoPipeline(RawModel):
|
||||
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
|
||||
management system.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
|
||||
self._pipeline = pipeline
|
||||
|
||||
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
|
||||
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
|
||||
assert results is not None
|
||||
results = [DetectionResult.model_validate(result) for result in results]
|
||||
return results
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
||||
# CUDA.
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
# HACK(ryand): Fix the circular import issue.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
@@ -1,43 +0,0 @@
|
||||
batch_size = 1
|
||||
modelname = "groundingdino"
|
||||
backbone = "swin_B_384_22k"
|
||||
position_embedding = "sine"
|
||||
pe_temperatureH = 20
|
||||
pe_temperatureW = 20
|
||||
return_interm_indices = [1, 2, 3]
|
||||
backbone_freeze_keywords = None
|
||||
enc_layers = 6
|
||||
dec_layers = 6
|
||||
pre_norm = False
|
||||
dim_feedforward = 2048
|
||||
hidden_dim = 256
|
||||
dropout = 0.0
|
||||
nheads = 8
|
||||
num_queries = 900
|
||||
query_dim = 4
|
||||
num_patterns = 0
|
||||
num_feature_levels = 4
|
||||
enc_n_points = 4
|
||||
dec_n_points = 4
|
||||
two_stage_type = "standard"
|
||||
two_stage_bbox_embed_share = False
|
||||
two_stage_class_embed_share = False
|
||||
transformer_activation = "relu"
|
||||
dec_pred_bbox_embed_share = True
|
||||
dn_box_noise_scale = 1.0
|
||||
dn_label_noise_ratio = 0.5
|
||||
dn_label_coef = 1.0
|
||||
dn_bbox_coef = 1.0
|
||||
embed_init_tgt = True
|
||||
dn_labelbook_size = 2000
|
||||
max_text_len = 256
|
||||
text_encoder_type = "bert-base-uncased"
|
||||
use_text_enhancer = True
|
||||
use_fusion_layer = True
|
||||
use_checkpoint = True
|
||||
use_transformer_ckpt = True
|
||||
use_text_cross_attention = True
|
||||
text_dropout = 0.0
|
||||
fusion_dropout = 0.0
|
||||
fusion_droppath = 0.1
|
||||
sub_sentence_present = True
|
||||
@@ -1,43 +0,0 @@
|
||||
batch_size = 1
|
||||
modelname = "groundingdino"
|
||||
backbone = "swin_T_224_1k"
|
||||
position_embedding = "sine"
|
||||
pe_temperatureH = 20
|
||||
pe_temperatureW = 20
|
||||
return_interm_indices = [1, 2, 3]
|
||||
backbone_freeze_keywords = None
|
||||
enc_layers = 6
|
||||
dec_layers = 6
|
||||
pre_norm = False
|
||||
dim_feedforward = 2048
|
||||
hidden_dim = 256
|
||||
dropout = 0.0
|
||||
nheads = 8
|
||||
num_queries = 900
|
||||
query_dim = 4
|
||||
num_patterns = 0
|
||||
num_feature_levels = 4
|
||||
enc_n_points = 4
|
||||
dec_n_points = 4
|
||||
two_stage_type = "standard"
|
||||
two_stage_bbox_embed_share = False
|
||||
two_stage_class_embed_share = False
|
||||
transformer_activation = "relu"
|
||||
dec_pred_bbox_embed_share = True
|
||||
dn_box_noise_scale = 1.0
|
||||
dn_label_noise_ratio = 0.5
|
||||
dn_label_coef = 1.0
|
||||
dn_bbox_coef = 1.0
|
||||
embed_init_tgt = True
|
||||
dn_labelbook_size = 2000
|
||||
max_text_len = 256
|
||||
text_encoder_type = "bert-base-uncased"
|
||||
use_text_enhancer = True
|
||||
use_fusion_layer = True
|
||||
use_checkpoint = True
|
||||
use_transformer_ckpt = True
|
||||
use_text_cross_attention = True
|
||||
text_dropout = 0.0
|
||||
fusion_dropout = 0.0
|
||||
fusion_droppath = 0.1
|
||||
sub_sentence_present = True
|
||||
@@ -1,299 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.box_ops import box_xyxy_to_cxcywh
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import interpolate
|
||||
|
||||
|
||||
def crop(image, target, region):
|
||||
cropped_image = F.crop(image, *region)
|
||||
|
||||
target = target.copy()
|
||||
i, j, h, w = region
|
||||
|
||||
# should we do something wrt the original size?
|
||||
target["size"] = torch.tensor([h, w])
|
||||
|
||||
fields = ["labels", "area", "iscrowd", "positive_map"]
|
||||
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
max_size = torch.as_tensor([w, h], dtype=torch.float32)
|
||||
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
|
||||
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
|
||||
cropped_boxes = cropped_boxes.clamp(min=0)
|
||||
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
|
||||
target["boxes"] = cropped_boxes.reshape(-1, 4)
|
||||
target["area"] = area
|
||||
fields.append("boxes")
|
||||
|
||||
if "masks" in target:
|
||||
# FIXME should we update the area here if there are no boxes?
|
||||
target["masks"] = target["masks"][:, i : i + h, j : j + w]
|
||||
fields.append("masks")
|
||||
|
||||
# remove elements for which the boxes or masks that have zero area
|
||||
if "boxes" in target or "masks" in target:
|
||||
# favor boxes selection when defining which elements to keep
|
||||
# this is compatible with previous implementation
|
||||
if "boxes" in target:
|
||||
cropped_boxes = target["boxes"].reshape(-1, 2, 2)
|
||||
keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
|
||||
else:
|
||||
keep = target["masks"].flatten(1).any(1)
|
||||
|
||||
for field in fields:
|
||||
if field in target:
|
||||
target[field] = target[field][keep]
|
||||
|
||||
if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
|
||||
# for debug and visualization only.
|
||||
if "strings_positive" in target:
|
||||
target["strings_positive"] = [_i for _i, _j in zip(target["strings_positive"], keep, strict=False) if _j]
|
||||
|
||||
return cropped_image, target
|
||||
|
||||
|
||||
def hflip(image, target):
|
||||
flipped_image = F.hflip(image)
|
||||
|
||||
w, h = image.size
|
||||
|
||||
target = target.copy()
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
|
||||
target["boxes"] = boxes
|
||||
|
||||
if "masks" in target:
|
||||
target["masks"] = target["masks"].flip(-1)
|
||||
|
||||
return flipped_image, target
|
||||
|
||||
|
||||
def resize(image, target, size, max_size=None):
|
||||
# size can be min_size (scalar) or (w, h) tuple
|
||||
|
||||
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
||||
w, h = image_size
|
||||
if max_size is not None:
|
||||
min_original_size = float(min((w, h)))
|
||||
max_original_size = float(max((w, h)))
|
||||
if max_original_size / min_original_size * size > max_size:
|
||||
size = int(round(max_size * min_original_size / max_original_size))
|
||||
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return (h, w)
|
||||
|
||||
if w < h:
|
||||
ow = size
|
||||
oh = int(size * h / w)
|
||||
else:
|
||||
oh = size
|
||||
ow = int(size * w / h)
|
||||
|
||||
return (oh, ow)
|
||||
|
||||
def get_size(image_size, size, max_size=None):
|
||||
if isinstance(size, (list, tuple)):
|
||||
return size[::-1]
|
||||
else:
|
||||
return get_size_with_aspect_ratio(image_size, size, max_size)
|
||||
|
||||
size = get_size(image.size, size, max_size)
|
||||
rescaled_image = F.resize(image, size)
|
||||
|
||||
if target is None:
|
||||
return rescaled_image, None
|
||||
|
||||
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False))
|
||||
ratio_width, ratio_height = ratios
|
||||
|
||||
target = target.copy()
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
|
||||
target["boxes"] = scaled_boxes
|
||||
|
||||
if "area" in target:
|
||||
area = target["area"]
|
||||
scaled_area = area * (ratio_width * ratio_height)
|
||||
target["area"] = scaled_area
|
||||
|
||||
h, w = size
|
||||
target["size"] = torch.tensor([h, w])
|
||||
|
||||
if "masks" in target:
|
||||
target["masks"] = interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
|
||||
|
||||
return rescaled_image, target
|
||||
|
||||
|
||||
def pad(image, target, padding):
|
||||
# assumes that we only pad on the bottom right corners
|
||||
padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
|
||||
if target is None:
|
||||
return padded_image, None
|
||||
target = target.copy()
|
||||
# should we do something wrt the original size?
|
||||
target["size"] = torch.tensor(padded_image.size[::-1])
|
||||
if "masks" in target:
|
||||
target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
|
||||
return padded_image, target
|
||||
|
||||
|
||||
class ResizeDebug(object):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img, target):
|
||||
return resize(img, target, self.size)
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img, target):
|
||||
region = T.RandomCrop.get_params(img, self.size)
|
||||
return crop(img, target, region)
|
||||
|
||||
|
||||
class RandomSizeCrop(object):
|
||||
def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
|
||||
# respect_boxes: True to keep all boxes
|
||||
# False to tolerence box filter
|
||||
self.min_size = min_size
|
||||
self.max_size = max_size
|
||||
self.respect_boxes = respect_boxes
|
||||
|
||||
def __call__(self, img: PIL.Image.Image, target: dict):
|
||||
init_boxes = len(target["boxes"])
|
||||
max_patience = 10
|
||||
for i in range(max_patience):
|
||||
w = random.randint(self.min_size, min(img.width, self.max_size))
|
||||
h = random.randint(self.min_size, min(img.height, self.max_size))
|
||||
region = T.RandomCrop.get_params(img, [h, w])
|
||||
result_img, result_target = crop(img, target, region)
|
||||
if not self.respect_boxes or len(result_target["boxes"]) == init_boxes or i == max_patience - 1:
|
||||
return result_img, result_target
|
||||
return result_img, result_target
|
||||
|
||||
|
||||
class CenterCrop(object):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img, target):
|
||||
image_width, image_height = img.size
|
||||
crop_height, crop_width = self.size
|
||||
crop_top = int(round((image_height - crop_height) / 2.0))
|
||||
crop_left = int(round((image_width - crop_width) / 2.0))
|
||||
return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
|
||||
|
||||
|
||||
class RandomHorizontalFlip(object):
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img, target):
|
||||
if random.random() < self.p:
|
||||
return hflip(img, target)
|
||||
return img, target
|
||||
|
||||
|
||||
class RandomResize(object):
|
||||
def __init__(self, sizes, max_size=None):
|
||||
assert isinstance(sizes, (list, tuple))
|
||||
self.sizes = sizes
|
||||
self.max_size = max_size
|
||||
|
||||
def __call__(self, img, target=None):
|
||||
size = random.choice(self.sizes)
|
||||
return resize(img, target, size, self.max_size)
|
||||
|
||||
|
||||
class RandomPad(object):
|
||||
def __init__(self, max_pad):
|
||||
self.max_pad = max_pad
|
||||
|
||||
def __call__(self, img, target):
|
||||
pad_x = random.randint(0, self.max_pad)
|
||||
pad_y = random.randint(0, self.max_pad)
|
||||
return pad(img, target, (pad_x, pad_y))
|
||||
|
||||
|
||||
class RandomSelect(object):
|
||||
"""
|
||||
Randomly selects between transforms1 and transforms2,
|
||||
with probability p for transforms1 and (1 - p) for transforms2
|
||||
"""
|
||||
|
||||
def __init__(self, transforms1, transforms2, p=0.5):
|
||||
self.transforms1 = transforms1
|
||||
self.transforms2 = transforms2
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img, target):
|
||||
if random.random() < self.p:
|
||||
return self.transforms1(img, target)
|
||||
return self.transforms2(img, target)
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
def __call__(self, img, target):
|
||||
return F.to_tensor(img), target
|
||||
|
||||
|
||||
class RandomErasing(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.eraser = T.RandomErasing(*args, **kwargs)
|
||||
|
||||
def __call__(self, img, target):
|
||||
return self.eraser(img), target
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, image, target=None):
|
||||
image = F.normalize(image, mean=self.mean, std=self.std)
|
||||
if target is None:
|
||||
return image, None
|
||||
target = target.copy()
|
||||
h, w = image.shape[-2:]
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
boxes = box_xyxy_to_cxcywh(boxes)
|
||||
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
|
||||
target["boxes"] = boxes
|
||||
return image, target
|
||||
|
||||
|
||||
class Compose(object):
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, image, target):
|
||||
for t in self.transforms:
|
||||
image, target = t(image, target)
|
||||
return image, target
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + "("
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += " {0}".format(t)
|
||||
format_string += "\n)"
|
||||
return format_string
|
||||
@@ -1,17 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Conditional DETR
|
||||
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Copied from DETR (https://github.com/facebookresearch/detr)
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.groundingdino import (
|
||||
build_groundingdino,
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
from .backbone import build_backbone
|
||||
@@ -1,217 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Conditional DETR
|
||||
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Copied from DETR (https://github.com/facebookresearch/detr)
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Backbone modules.
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.position_encoding import (
|
||||
build_position_encoding,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.models.GroundingDINO.backbone.swin_transformer import (
|
||||
build_swin_transformer,
|
||||
)
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other models than torchvision.models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super(FrozenBatchNorm2d, self).__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
backbone: nn.Module,
|
||||
train_backbone: bool,
|
||||
num_channels: int,
|
||||
return_interm_indices: list,
|
||||
):
|
||||
super().__init__()
|
||||
for name, parameter in backbone.named_parameters():
|
||||
if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name:
|
||||
parameter.requires_grad_(False)
|
||||
|
||||
return_layers = {}
|
||||
for idx, layer_index in enumerate(return_interm_indices):
|
||||
return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})
|
||||
|
||||
# if len:
|
||||
# if use_stage1_feature:
|
||||
# return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
# else:
|
||||
# return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
|
||||
# else:
|
||||
# return_layers = {'layer4': "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self.body(tensor_list.tensors)
|
||||
out: Dict[str, NestedTensor] = {}
|
||||
for name, x in xs.items():
|
||||
m = tensor_list.mask
|
||||
assert m is not None
|
||||
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
out[name] = NestedTensor(x, mask)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
train_backbone: bool,
|
||||
dilation: bool,
|
||||
return_interm_indices: list,
|
||||
batch_norm=FrozenBatchNorm2d,
|
||||
):
|
||||
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(),
|
||||
norm_layer=batch_norm,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Why you can get here with name {}".format(name))
|
||||
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
||||
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
|
||||
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
||||
num_channels_all = [256, 512, 1024, 2048]
|
||||
num_channels = num_channels_all[4 - len(return_interm_indices) :]
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for name, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.tensors.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
"""
|
||||
Useful args:
|
||||
- backbone: backbone name
|
||||
- lr_backbone:
|
||||
- dilation
|
||||
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
|
||||
- backbone_freeze_keywords:
|
||||
- use_checkpoint: for swin only for now
|
||||
|
||||
"""
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = True
|
||||
if not train_backbone:
|
||||
raise ValueError("Please set lr_backbone > 0")
|
||||
return_interm_indices = args.return_interm_indices
|
||||
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
||||
args.backbone_freeze_keywords
|
||||
use_checkpoint = getattr(args, "use_checkpoint", False)
|
||||
|
||||
if args.backbone in ["resnet50", "resnet101"]:
|
||||
backbone = Backbone(
|
||||
args.backbone,
|
||||
train_backbone,
|
||||
args.dilation,
|
||||
return_interm_indices,
|
||||
batch_norm=FrozenBatchNorm2d,
|
||||
)
|
||||
bb_num_channels = backbone.num_channels
|
||||
elif args.backbone in [
|
||||
"swin_T_224_1k",
|
||||
"swin_B_224_22k",
|
||||
"swin_B_384_22k",
|
||||
"swin_L_224_22k",
|
||||
"swin_L_384_22k",
|
||||
]:
|
||||
pretrain_img_size = int(args.backbone.split("_")[-2])
|
||||
backbone = build_swin_transformer(
|
||||
args.backbone,
|
||||
pretrain_img_size=pretrain_img_size,
|
||||
out_indices=tuple(return_interm_indices),
|
||||
dilation=False,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
|
||||
else:
|
||||
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
|
||||
|
||||
assert len(bb_num_channels) == len(
|
||||
return_interm_indices
|
||||
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
|
||||
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = bb_num_channels
|
||||
assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but {}".format(
|
||||
type(bb_num_channels)
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
return model
|
||||
@@ -1,176 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# DINO
|
||||
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Conditional DETR
|
||||
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Copied from DETR (https://github.com/facebookresearch/detr)
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
mask = tensor_list.mask
|
||||
assert mask is not None
|
||||
not_mask = ~mask
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
# if os.environ.get("SHILONG_AMP", None) == '1':
|
||||
# eps = 1e-4
|
||||
# else:
|
||||
# eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingSineHW(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperatureH = temperatureH
|
||||
self.temperatureW = temperatureW
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
mask = tensor_list.mask
|
||||
assert mask is not None
|
||||
not_mask = ~mask
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode="floor")) / self.num_pos_feats)
|
||||
pos_x = x_embed[:, :, :, None] / dim_tx
|
||||
|
||||
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode="floor")) / self.num_pos_feats)
|
||||
pos_y = y_embed[:, :, :, None] / dim_ty
|
||||
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = (
|
||||
torch.cat(
|
||||
[
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(x.shape[0], 1, 1, 1)
|
||||
)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
N_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ("v2", "sine"):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSineHW(
|
||||
N_steps,
|
||||
temperatureH=args.pe_temperatureH,
|
||||
temperatureW=args.pe_temperatureW,
|
||||
normalize=True,
|
||||
)
|
||||
elif args.position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
@@ -1,766 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# DINO
|
||||
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import NestedTensor
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""Multilayer perceptron."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
"""Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""Forward function.
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
"""Swin Transformer Block.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, mask_matrix):
|
||||
"""Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
mask_matrix: Attention mask for cyclic shift.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# pad feature maps to multiples of window size
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
attn_mask = None
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""Patch Merging Layer
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
"""Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""A basic Swin Transformer layer for one stage.
|
||||
Args:
|
||||
dim (int): Number of feature channels
|
||||
depth (int): Depths of this stage.
|
||||
num_heads (int): Number of attention head.
|
||||
window_size (int): Local window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
"""Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, H*W, C).
|
||||
H, W: Spatial resolution of the input feature.
|
||||
"""
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device, dtype=x.dtype) # 1 Hp Wp 1
|
||||
h_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None),
|
||||
)
|
||||
w_slices = (
|
||||
slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None),
|
||||
)
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, attn_mask)
|
||||
else:
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
||||
super().__init__()
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
"""Swin Transformer backbone.
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
Args:
|
||||
pretrain_img_size (int): Input image size for training the pretrained model,
|
||||
used in absolute postion embedding. Default 224.
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||
num_heads (tuple[int]): Number of attention head of each stage.
|
||||
window_size (int): Window size. Default: 7.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop_rate (float): Dropout rate.
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters.
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.2,
|
||||
norm_layer=nn.LayerNorm,
|
||||
ape=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
dilation=False,
|
||||
use_checkpoint=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.ape = ape
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.dilation = dilation
|
||||
|
||||
# if use_checkpoint:
|
||||
# print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None,
|
||||
)
|
||||
|
||||
# absolute position embedding
|
||||
if self.ape:
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [
|
||||
pretrain_img_size[0] // patch_size[0],
|
||||
pretrain_img_size[1] // patch_size[1],
|
||||
]
|
||||
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
|
||||
)
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
# prepare downsample list
|
||||
downsamplelist = [PatchMerging for i in range(self.num_layers)]
|
||||
downsamplelist[-1] = None
|
||||
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
|
||||
if self.dilation:
|
||||
downsamplelist[-2] = None
|
||||
num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
# dim=int(embed_dim * 2 ** i_layer),
|
||||
dim=num_features[i_layer],
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
# downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
downsample=downsamplelist[i_layer],
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
self.layers.append(layer)
|
||||
|
||||
# num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
# add a norm layer for each output
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f"norm{i_layer}"
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 1 and self.ape:
|
||||
self.absolute_pos_embed.requires_grad = False
|
||||
|
||||
if self.frozen_stages >= 2:
|
||||
self.pos_drop.eval()
|
||||
for i in range(0, self.frozen_stages - 1):
|
||||
m = self.layers[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# def init_weights(self, pretrained=None):
|
||||
# """Initialize the weights in backbone.
|
||||
# Args:
|
||||
# pretrained (str, optional): Path to pre-trained weights.
|
||||
# Defaults to None.
|
||||
# """
|
||||
|
||||
# def _init_weights(m):
|
||||
# if isinstance(m, nn.Linear):
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
# if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
# elif isinstance(m, nn.LayerNorm):
|
||||
# nn.init.constant_(m.bias, 0)
|
||||
# nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
# if isinstance(pretrained, str):
|
||||
# self.apply(_init_weights)
|
||||
# logger = get_root_logger()
|
||||
# load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
# elif pretrained is None:
|
||||
# self.apply(_init_weights)
|
||||
# else:
|
||||
# raise TypeError('pretrained must be a str or None')
|
||||
|
||||
def forward_raw(self, x):
|
||||
"""Forward function."""
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
outs = []
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f"norm{i}")
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
# in:
|
||||
# torch.Size([2, 3, 1024, 1024])
|
||||
# outs:
|
||||
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
||||
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
||||
return tuple(outs)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
|
||||
"""Forward function."""
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
if self.ape:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic")
|
||||
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
|
||||
else:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
outs = []
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f"norm{i}")
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
# in:
|
||||
# torch.Size([2, 3, 1024, 1024])
|
||||
# out:
|
||||
# [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
|
||||
# torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
|
||||
|
||||
# collect for nesttensors
|
||||
outs_dict = {}
|
||||
for idx, out_i in enumerate(outs):
|
||||
m = tensor_list.mask
|
||||
assert m is not None
|
||||
mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
|
||||
outs_dict[idx] = NestedTensor(out_i, mask)
|
||||
|
||||
return outs_dict
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep layers freezed."""
|
||||
super(SwinTransformer, self).train(mode)
|
||||
self._freeze_stages()
|
||||
|
||||
|
||||
def build_swin_transformer(modelname, pretrain_img_size, **kw):
|
||||
assert modelname in [
|
||||
"swin_T_224_1k",
|
||||
"swin_B_224_22k",
|
||||
"swin_B_384_22k",
|
||||
"swin_L_224_22k",
|
||||
"swin_L_384_22k",
|
||||
]
|
||||
|
||||
model_para_dict = {
|
||||
"swin_T_224_1k": dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7),
|
||||
"swin_B_224_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7),
|
||||
"swin_B_384_22k": dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12),
|
||||
"swin_L_224_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7),
|
||||
"swin_L_384_22k": dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12),
|
||||
}
|
||||
kw_cgf = model_para_dict[modelname]
|
||||
kw_cgf.update(kw)
|
||||
model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
|
||||
x = torch.rand(2, 3, 1024, 1024)
|
||||
y = model.forward_raw(x)
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
x = torch.rand(2, 3, 384, 384)
|
||||
y = model.forward_raw(x)
|
||||
@@ -1,250 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
|
||||
|
||||
|
||||
class BertModelWarper(nn.Module):
|
||||
def __init__(self, bert_model):
|
||||
super().__init__()
|
||||
# self.bert = bert_modelc
|
||||
|
||||
self.config = bert_model.config
|
||||
self.embeddings = bert_model.embeddings
|
||||
self.encoder = bert_model.encoder
|
||||
self.pooler = bert_model.pooler
|
||||
|
||||
self.get_extended_attention_mask = bert_model.get_extended_attention_mask
|
||||
self.invert_attention_mask = bert_model.invert_attention_mask
|
||||
self.get_head_mask = bert_model.get_head_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
batch_size, seq_length = input_shape
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
class TextEncoderShell(nn.Module):
|
||||
def __init__(self, text_encoder):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
self.config = self.text_encoder.config
|
||||
|
||||
def forward(self, **kw):
|
||||
# feed into text encoder
|
||||
return self.text_encoder(**kw)
|
||||
|
||||
|
||||
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
|
||||
"""Generate attention mask between each pair of special tokens
|
||||
Args:
|
||||
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
|
||||
special_tokens_mask (list): special tokens mask.
|
||||
Returns:
|
||||
torch.Tensor: attention mask between each special tokens.
|
||||
"""
|
||||
input_ids = tokenized["input_ids"]
|
||||
bs, num_token = input_ids.shape
|
||||
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
|
||||
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
|
||||
for special_token in special_tokens_list:
|
||||
special_tokens_mask |= input_ids == special_token
|
||||
|
||||
# idxs: each row is a list of indices of special tokens
|
||||
idxs = torch.nonzero(special_tokens_mask)
|
||||
|
||||
# generate attention mask and positional ids
|
||||
attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
|
||||
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
|
||||
previous_col = 0
|
||||
for i in range(idxs.shape[0]):
|
||||
row, col = idxs[i]
|
||||
if (col == 0) or (col == num_token - 1):
|
||||
attention_mask[row, col, col] = True
|
||||
position_ids[row, col] = 0
|
||||
else:
|
||||
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
|
||||
position_ids[row, previous_col + 1 : col + 1] = torch.arange(0, col - previous_col, device=input_ids.device)
|
||||
|
||||
previous_col = col
|
||||
|
||||
# # padding mask
|
||||
# padding_mask = tokenized['attention_mask']
|
||||
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
|
||||
|
||||
return attention_mask, position_ids.to(torch.long)
|
||||
|
||||
|
||||
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
|
||||
"""Generate attention mask between each pair of special tokens
|
||||
Args:
|
||||
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
|
||||
special_tokens_mask (list): special tokens mask.
|
||||
Returns:
|
||||
torch.Tensor: attention mask between each special tokens.
|
||||
"""
|
||||
input_ids = tokenized["input_ids"]
|
||||
bs, num_token = input_ids.shape
|
||||
# special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
|
||||
special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
|
||||
for special_token in special_tokens_list:
|
||||
special_tokens_mask |= input_ids == special_token
|
||||
|
||||
# idxs: each row is a list of indices of special tokens
|
||||
idxs = torch.nonzero(special_tokens_mask)
|
||||
|
||||
# generate attention mask and positional ids
|
||||
attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
|
||||
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
|
||||
cate_to_token_mask_list = [[] for _ in range(bs)]
|
||||
previous_col = 0
|
||||
for i in range(idxs.shape[0]):
|
||||
row, col = idxs[i]
|
||||
if (col == 0) or (col == num_token - 1):
|
||||
attention_mask[row, col, col] = True
|
||||
position_ids[row, col] = 0
|
||||
else:
|
||||
attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
|
||||
position_ids[row, previous_col + 1 : col + 1] = torch.arange(0, col - previous_col, device=input_ids.device)
|
||||
c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
|
||||
c2t_maski[previous_col + 1 : col] = True
|
||||
cate_to_token_mask_list[row].append(c2t_maski)
|
||||
previous_col = col
|
||||
|
||||
cate_to_token_mask_list = [
|
||||
torch.stack(cate_to_token_mask_listi, dim=0) for cate_to_token_mask_listi in cate_to_token_mask_list
|
||||
]
|
||||
|
||||
# # padding mask
|
||||
# padding_mask = tokenized['attention_mask']
|
||||
# attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
|
||||
|
||||
return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
|
||||
@@ -1,295 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import DropPath
|
||||
|
||||
|
||||
class FeatureResizer(nn.Module):
|
||||
"""
|
||||
This class takes as input a set of embeddings of dimension C1 and outputs a set of
|
||||
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
|
||||
"""
|
||||
|
||||
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
|
||||
super().__init__()
|
||||
self.do_ln = do_ln
|
||||
# Object feature encoding
|
||||
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
|
||||
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, encoder_features):
|
||||
x = self.fc(encoder_features)
|
||||
if self.do_ln:
|
||||
x = self.layer_norm(x)
|
||||
output = self.dropout(x)
|
||||
return output
|
||||
|
||||
|
||||
def l1norm(X, dim, eps=1e-8):
|
||||
"""L1-normalize columns of X"""
|
||||
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def l2norm(X, dim, eps=1e-8):
|
||||
"""L2-normalize columns of X"""
|
||||
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
|
||||
X = torch.div(X, norm)
|
||||
return X
|
||||
|
||||
|
||||
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
|
||||
"""
|
||||
query: (n_context, queryL, d)
|
||||
context: (n_context, sourceL, d)
|
||||
"""
|
||||
_, queryL = query.size(0), query.size(1)
|
||||
batch_size, sourceL = context.size(0), context.size(1)
|
||||
|
||||
# Get attention
|
||||
# --> (batch, d, queryL)
|
||||
queryT = torch.transpose(query, 1, 2)
|
||||
|
||||
# (batch, sourceL, d)(batch, d, queryL)
|
||||
# --> (batch, sourceL, queryL)
|
||||
attn = torch.bmm(context, queryT)
|
||||
if raw_feature_norm == "softmax":
|
||||
# --> (batch*sourceL, queryL)
|
||||
attn = attn.view(batch_size * sourceL, queryL)
|
||||
attn = nn.Softmax()(attn)
|
||||
# --> (batch, sourceL, queryL)
|
||||
attn = attn.view(batch_size, sourceL, queryL)
|
||||
elif raw_feature_norm == "l2norm":
|
||||
attn = l2norm(attn, 2)
|
||||
elif raw_feature_norm == "clipped_l2norm":
|
||||
attn = nn.LeakyReLU(0.1)(attn)
|
||||
attn = l2norm(attn, 2)
|
||||
else:
|
||||
raise ValueError("unknown first norm type:", raw_feature_norm)
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = torch.transpose(attn, 1, 2).contiguous()
|
||||
# --> (batch*queryL, sourceL)
|
||||
attn = attn.view(batch_size * queryL, sourceL)
|
||||
attn = nn.Softmax()(attn * smooth)
|
||||
# --> (batch, queryL, sourceL)
|
||||
attn = attn.view(batch_size, queryL, sourceL)
|
||||
# --> (batch, sourceL, queryL)
|
||||
attnT = torch.transpose(attn, 1, 2).contiguous()
|
||||
|
||||
# --> (batch, d, sourceL)
|
||||
contextT = torch.transpose(context, 1, 2)
|
||||
# (batch x d x sourceL)(batch x sourceL x queryL)
|
||||
# --> (batch, d, queryL)
|
||||
weightedContext = torch.bmm(contextT, attnT)
|
||||
# --> (batch, queryL, d)
|
||||
weightedContext = torch.transpose(weightedContext, 1, 2)
|
||||
|
||||
return weightedContext, attnT
|
||||
|
||||
|
||||
class BiMultiHeadAttention(nn.Module):
|
||||
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
|
||||
super(BiMultiHeadAttention, self).__init__()
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.v_dim = v_dim
|
||||
self.l_dim = l_dim
|
||||
|
||||
assert (
|
||||
self.head_dim * self.num_heads == self.embed_dim
|
||||
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and \
|
||||
`num_heads`: {self.num_heads})."
|
||||
self.scale = self.head_dim ** (-0.5)
|
||||
self.dropout = dropout
|
||||
|
||||
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
||||
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
||||
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
||||
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
||||
|
||||
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
|
||||
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
|
||||
|
||||
self.stable_softmax_2d = True
|
||||
self.clamp_min_for_underflow = True
|
||||
self.clamp_max_for_overflow = True
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def _reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
self.v_proj.bias.data.fill_(0)
|
||||
nn.init.xavier_uniform_(self.l_proj.weight)
|
||||
self.l_proj.bias.data.fill_(0)
|
||||
nn.init.xavier_uniform_(self.values_v_proj.weight)
|
||||
self.values_v_proj.bias.data.fill_(0)
|
||||
nn.init.xavier_uniform_(self.values_l_proj.weight)
|
||||
self.values_l_proj.bias.data.fill_(0)
|
||||
nn.init.xavier_uniform_(self.out_v_proj.weight)
|
||||
self.out_v_proj.bias.data.fill_(0)
|
||||
nn.init.xavier_uniform_(self.out_l_proj.weight)
|
||||
self.out_l_proj.bias.data.fill_(0)
|
||||
|
||||
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
v (_type_): bs, n_img, dim
|
||||
l (_type_): bs, n_text, dim
|
||||
attention_mask_v (_type_, optional): _description_. bs, n_img
|
||||
attention_mask_l (_type_, optional): _description_. bs, n_text
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
||||
# import ipdb; ipdb.set_trace()
|
||||
bsz, tgt_len, _ = v.size()
|
||||
|
||||
query_states = self.v_proj(v) * self.scale
|
||||
key_states = self._shape(self.l_proj(l), -1, bsz)
|
||||
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
|
||||
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_v_states = value_v_states.view(*proj_shape)
|
||||
value_l_states = value_l_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, \
|
||||
but is {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if self.stable_softmax_2d:
|
||||
attn_weights = attn_weights - attn_weights.max()
|
||||
|
||||
if self.clamp_min_for_underflow:
|
||||
attn_weights = torch.clamp(
|
||||
attn_weights, min=-50000
|
||||
) # Do not increase -50000, data type half has quite limited range
|
||||
if self.clamp_max_for_overflow:
|
||||
attn_weights = torch.clamp(
|
||||
attn_weights, max=50000
|
||||
) # Do not increase 50000, data type half has quite limited range
|
||||
|
||||
attn_weights_T = attn_weights.transpose(1, 2)
|
||||
attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
|
||||
if self.clamp_min_for_underflow:
|
||||
attn_weights_l = torch.clamp(
|
||||
attn_weights_l, min=-50000
|
||||
) # Do not increase -50000, data type half has quite limited range
|
||||
if self.clamp_max_for_overflow:
|
||||
attn_weights_l = torch.clamp(
|
||||
attn_weights_l, max=50000
|
||||
) # Do not increase 50000, data type half has quite limited range
|
||||
|
||||
# mask vison for language
|
||||
if attention_mask_v is not None:
|
||||
attention_mask_v = attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
||||
attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
|
||||
|
||||
attn_weights_l = attn_weights_l.softmax(dim=-1)
|
||||
|
||||
# mask language for vision
|
||||
if attention_mask_l is not None:
|
||||
attention_mask_l = attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
||||
attn_weights.masked_fill_(attention_mask_l, float("-inf"))
|
||||
attn_weights_v = attn_weights.softmax(dim=-1)
|
||||
|
||||
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
|
||||
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
|
||||
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
|
||||
|
||||
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, \
|
||||
but is {attn_output_v.size()}"
|
||||
)
|
||||
|
||||
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, \
|
||||
but is {attn_output_l.size()}"
|
||||
)
|
||||
|
||||
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output_v = attn_output_v.transpose(1, 2)
|
||||
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
|
||||
attn_output_l = attn_output_l.transpose(1, 2)
|
||||
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
|
||||
|
||||
attn_output_v = self.out_v_proj(attn_output_v)
|
||||
attn_output_l = self.out_l_proj(attn_output_l)
|
||||
|
||||
return attn_output_v, attn_output_l
|
||||
|
||||
|
||||
# Bi-Direction MHA (text->image, image->text)
|
||||
class BiAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
v_dim,
|
||||
l_dim,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.1,
|
||||
drop_path=0.0,
|
||||
init_values=1e-4,
|
||||
cfg=None,
|
||||
):
|
||||
"""
|
||||
Inputs:
|
||||
embed_dim - Dimensionality of input and attention feature vectors
|
||||
hidden_dim - Dimensionality of hidden layer in feed-forward network
|
||||
(usually 2-4x larger than embed_dim)
|
||||
num_heads - Number of heads to use in the Multi-Head Attention block
|
||||
dropout - Amount of dropout to apply in the feed-forward network
|
||||
"""
|
||||
super(BiAttentionBlock, self).__init__()
|
||||
|
||||
# pre layer norm
|
||||
self.layer_norm_v = nn.LayerNorm(v_dim)
|
||||
self.layer_norm_l = nn.LayerNorm(l_dim)
|
||||
self.attn = BiMultiHeadAttention(
|
||||
v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
|
||||
)
|
||||
|
||||
# add layer scale for training stability
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
|
||||
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
|
||||
|
||||
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
||||
v = self.layer_norm_v(v)
|
||||
l = self.layer_norm_l(l)
|
||||
delta_v, delta_l = self.attn(v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l)
|
||||
# v, l = v + delta_v, l + delta_l
|
||||
v = v + self.drop_path(self.gamma_v * delta_v)
|
||||
l = l + self.drop_path(self.gamma_l * delta_l)
|
||||
return v, l
|
||||
|
||||
# def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
|
||||
@@ -1,362 +0,0 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Grounding DINO
|
||||
# url: https://github.com/IDEA-Research/GroundingDINO
|
||||
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Conditional DETR model and criterion classes.
|
||||
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from DETR (https://github.com/facebookresearch/detr)
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
|
||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util import get_tokenlizer
|
||||
from invokeai.backend.image_util.grounding_segment_anything.groundingdino.util.misc import (
|
||||
NestedTensor,
|
||||
inverse_sigmoid,
|
||||
nested_tensor_from_tensor_list,
|
||||
)
|
||||
|
||||
from ..registry import MODULE_BUILD_FUNCS
|
||||
from .backbone import build_backbone
|
||||
from .bertwarper import BertModelWarper, generate_masks_with_special_tokens_and_transfer_map
|
||||
from .transformer import build_transformer
|
||||
from .utils import MLP, ContrastiveEmbed
|
||||
|
||||
|
||||
class GroundingDINO(nn.Module):
|
||||
"""This is the Cross-Attention Detector module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone,
|
||||
transformer,
|
||||
num_queries,
|
||||
aux_loss=False,
|
||||
iter_update=False,
|
||||
query_dim=2,
|
||||
num_feature_levels=1,
|
||||
nheads=8,
|
||||
# two stage
|
||||
two_stage_type="no", # ['no', 'standard']
|
||||
dec_pred_bbox_embed_share=True,
|
||||
two_stage_class_embed_share=True,
|
||||
two_stage_bbox_embed_share=True,
|
||||
num_patterns=0,
|
||||
dn_number=100,
|
||||
dn_box_noise_scale=0.4,
|
||||
dn_label_noise_ratio=0.5,
|
||||
dn_labelbook_size=100,
|
||||
text_encoder_type="bert-base-uncased",
|
||||
sub_sentence_present=True,
|
||||
max_text_len=256,
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbone: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.transformer = transformer
|
||||
self.hidden_dim = hidden_dim = transformer.d_model
|
||||
self.num_feature_levels = num_feature_levels
|
||||
self.nheads = nheads
|
||||
self.max_text_len = 256
|
||||
self.sub_sentence_present = sub_sentence_present
|
||||
|
||||
# setting query dim
|
||||
self.query_dim = query_dim
|
||||
assert query_dim == 4
|
||||
|
||||
# for dn training
|
||||
self.num_patterns = num_patterns
|
||||
self.dn_number = dn_number
|
||||
self.dn_box_noise_scale = dn_box_noise_scale
|
||||
self.dn_label_noise_ratio = dn_label_noise_ratio
|
||||
self.dn_labelbook_size = dn_labelbook_size
|
||||
|
||||
# bert
|
||||
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
|
||||
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
|
||||
self.bert.pooler.dense.weight.requires_grad_(False)
|
||||
self.bert.pooler.dense.bias.requires_grad_(False)
|
||||
self.bert = BertModelWarper(bert_model=self.bert)
|
||||
|
||||
self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
|
||||
nn.init.constant_(self.feat_map.bias.data, 0)
|
||||
nn.init.xavier_uniform_(self.feat_map.weight.data)
|
||||
# freeze
|
||||
|
||||
# special tokens
|
||||
self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
|
||||
|
||||
# prepare input projection layers
|
||||
if num_feature_levels > 1:
|
||||
num_backbone_outs = len(backbone.num_channels)
|
||||
input_proj_list = []
|
||||
for _ in range(num_backbone_outs):
|
||||
in_channels = backbone.num_channels[_]
|
||||
input_proj_list.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
|
||||
nn.GroupNorm(32, hidden_dim),
|
||||
)
|
||||
)
|
||||
for _ in range(num_feature_levels - num_backbone_outs):
|
||||
input_proj_list.append(
|
||||
nn.Sequential(
|
||||
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(32, hidden_dim),
|
||||
)
|
||||
)
|
||||
in_channels = hidden_dim
|
||||
self.input_proj = nn.ModuleList(input_proj_list)
|
||||
else:
|
||||
assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
|
||||
self.input_proj = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
|
||||
nn.GroupNorm(32, hidden_dim),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
self.backbone = backbone
|
||||
self.aux_loss = aux_loss
|
||||
self.box_pred_damping = None
|
||||
|
||||
self.iter_update = iter_update
|
||||
assert iter_update, "Why not iter_update?"
|
||||
|
||||
# prepare pred layers
|
||||
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
|
||||
# prepare class & box embed
|
||||
_class_embed = ContrastiveEmbed()
|
||||
|
||||
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
|
||||
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
|
||||
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
|
||||
|
||||
if dec_pred_bbox_embed_share:
|
||||
box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
|
||||
else:
|
||||
box_embed_layerlist = [copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)]
|
||||
class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
|
||||
self.bbox_embed = nn.ModuleList(box_embed_layerlist)
|
||||
self.class_embed = nn.ModuleList(class_embed_layerlist)
|
||||
self.transformer.decoder.bbox_embed = self.bbox_embed
|
||||
self.transformer.decoder.class_embed = self.class_embed
|
||||
|
||||
# two stage
|
||||
self.two_stage_type = two_stage_type
|
||||
assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(two_stage_type)
|
||||
if two_stage_type != "no":
|
||||
if two_stage_bbox_embed_share:
|
||||
assert dec_pred_bbox_embed_share
|
||||
self.transformer.enc_out_bbox_embed = _bbox_embed
|
||||
else:
|
||||
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
|
||||
|
||||
if two_stage_class_embed_share:
|
||||
assert dec_pred_bbox_embed_share
|
||||
self.transformer.enc_out_class_embed = _class_embed
|
||||
else:
|
||||
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
|
||||
|
||||
self.refpoint_embed = None
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
# init input_proj
|
||||
for proj in self.input_proj:
|
||||
nn.init.xavier_uniform_(proj[0].weight, gain=1)
|
||||
nn.init.constant_(proj[0].bias, 0)
|
||||
|
||||
def init_ref_points(self, use_num_queries):
|
||||
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
|
||||
|
||||
def forward(self, samples: NestedTensor, targets: List = None, **kw):
|
||||
"""The forward expects a NestedTensor, which consists of:
|
||||
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
|
||||
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
|
||||
|
||||
It returns a dict with the following elements:
|
||||
- "pred_logits": the classification logits (including no-object) for all queries.
|
||||
Shape= [batch_size x num_queries x num_classes]
|
||||
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
|
||||
(center_x, center_y, width, height). These values are normalized in [0, 1],
|
||||
relative to the size of each individual image (disregarding possible padding).
|
||||
See PostProcess for information on how to retrieve the unnormalized bounding box.
|
||||
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
|
||||
dictionnaries containing the two above keys for each decoder layer.
|
||||
"""
|
||||
if targets is None:
|
||||
captions = kw["captions"]
|
||||
else:
|
||||
captions = [t["caption"] for t in targets]
|
||||
len(captions)
|
||||
|
||||
# encoder texts
|
||||
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(samples.device)
|
||||
(
|
||||
text_self_attention_masks,
|
||||
position_ids,
|
||||
cate_to_token_mask_list,
|
||||
) = generate_masks_with_special_tokens_and_transfer_map(tokenized, self.specical_tokens, self.tokenizer)
|
||||
|
||||
if text_self_attention_masks.shape[1] > self.max_text_len:
|
||||
text_self_attention_masks = text_self_attention_masks[:, : self.max_text_len, : self.max_text_len]
|
||||
position_ids = position_ids[:, : self.max_text_len]
|
||||
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
|
||||
tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
|
||||
tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
|
||||
|
||||
# extract text embeddings
|
||||
if self.sub_sentence_present:
|
||||
tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
|
||||
tokenized_for_encoder["attention_mask"] = text_self_attention_masks
|
||||
tokenized_for_encoder["position_ids"] = position_ids
|
||||
else:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
tokenized_for_encoder = tokenized
|
||||
|
||||
bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
|
||||
|
||||
encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
|
||||
text_token_mask = tokenized.attention_mask.bool() # bs, 195
|
||||
# text_token_mask: True for nomask, False for mask
|
||||
# text_self_attention_masks: True for nomask, False for mask
|
||||
|
||||
if encoded_text.shape[1] > self.max_text_len:
|
||||
encoded_text = encoded_text[:, : self.max_text_len, :]
|
||||
text_token_mask = text_token_mask[:, : self.max_text_len]
|
||||
position_ids = position_ids[:, : self.max_text_len]
|
||||
text_self_attention_masks = text_self_attention_masks[:, : self.max_text_len, : self.max_text_len]
|
||||
|
||||
text_dict = {
|
||||
"encoded_text": encoded_text, # bs, 195, d_model
|
||||
"text_token_mask": text_token_mask, # bs, 195
|
||||
"position_ids": position_ids, # bs, 195
|
||||
"text_self_attention_masks": text_self_attention_masks, # bs, 195,195
|
||||
}
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if isinstance(samples, (list, torch.Tensor)):
|
||||
samples = nested_tensor_from_tensor_list(samples)
|
||||
features, poss = self.backbone(samples)
|
||||
|
||||
srcs = []
|
||||
masks = []
|
||||
for l, feat in enumerate(features):
|
||||
src, mask = feat.decompose()
|
||||
srcs.append(self.input_proj[l](src))
|
||||
masks.append(mask)
|
||||
assert mask is not None
|
||||
if self.num_feature_levels > len(srcs):
|
||||
_len_srcs = len(srcs)
|
||||
for l in range(_len_srcs, self.num_feature_levels):
|
||||
if l == _len_srcs:
|
||||
src = self.input_proj[l](features[-1].tensors)
|
||||
else:
|
||||
src = self.input_proj[l](srcs[-1])
|
||||
m = samples.mask
|
||||
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
|
||||
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
|
||||
srcs.append(src)
|
||||
masks.append(mask)
|
||||
poss.append(pos_l)
|
||||
|
||||
input_query_bbox = input_query_label = attn_mask = None
|
||||
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
|
||||
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
|
||||
)
|
||||
|
||||
# deformable-detr-like anchor update
|
||||
outputs_coord_list = []
|
||||
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
|
||||
layer_delta_unsig = layer_bbox_embed(layer_hs)
|
||||
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
|
||||
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
|
||||
outputs_coord_list.append(layer_outputs_unsig)
|
||||
outputs_coord_list = torch.stack(outputs_coord_list)
|
||||
|
||||
# output
|
||||
outputs_class = torch.stack(
|
||||
[layer_cls_embed(layer_hs, text_dict) for layer_cls_embed, layer_hs in zip(self.class_embed, hs)]
|
||||
)
|
||||
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
|
||||
|
||||
# # for intermediate outputs
|
||||
# if self.aux_loss:
|
||||
# out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
|
||||
|
||||
# # for encoder output
|
||||
# if hs_enc is not None:
|
||||
# # prepare intermediate outputs
|
||||
# interm_coord = ref_enc[-1]
|
||||
# interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
|
||||
# out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
|
||||
# out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
|
||||
|
||||
return out
|
||||
|
||||
@torch.jit.unused
|
||||
def _set_aux_loss(self, outputs_class, outputs_coord):
|
||||
# this is a workaround to make torchscript happy, as torchscript
|
||||
# doesn't support dictionary with non-homogeneous values, such
|
||||
# as a dict having both a Tensor and a list.
|
||||
return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
||||
|
||||
|
||||
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
|
||||
def build_groundingdino(args):
|
||||
|
||||
backbone = build_backbone(args)
|
||||
transformer = build_transformer(args)
|
||||
|
||||
dn_labelbook_size = args.dn_labelbook_size
|
||||
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
|
||||
sub_sentence_present = args.sub_sentence_present
|
||||
|
||||
model = GroundingDINO(
|
||||
backbone,
|
||||
transformer,
|
||||
num_queries=args.num_queries,
|
||||
aux_loss=True,
|
||||
iter_update=True,
|
||||
query_dim=4,
|
||||
num_feature_levels=args.num_feature_levels,
|
||||
nheads=args.nheads,
|
||||
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
|
||||
two_stage_type=args.two_stage_type,
|
||||
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
|
||||
two_stage_class_embed_share=args.two_stage_class_embed_share,
|
||||
num_patterns=args.num_patterns,
|
||||
dn_number=0,
|
||||
dn_box_noise_scale=args.dn_box_noise_scale,
|
||||
dn_label_noise_ratio=args.dn_label_noise_ratio,
|
||||
dn_labelbook_size=dn_labelbook_size,
|
||||
text_encoder_type=args.text_encoder_type,
|
||||
sub_sentence_present=sub_sentence_present,
|
||||
max_text_len=args.max_text_len,
|
||||
)
|
||||
|
||||
return model
|
||||