Compare commits
971 Commits
psyche/boa
...
v4.2.9.dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6af5a22d3a | ||
|
|
8cf4321010 | ||
|
|
aa7f2b096a | ||
|
|
bf0824b56d | ||
|
|
056c56d322 | ||
|
|
afc6f83d72 | ||
|
|
c776ac3af2 | ||
|
|
b7b3683bef | ||
|
|
fb26b6824a | ||
|
|
63d8ad912f | ||
|
|
bbd7d7fc17 | ||
|
|
6507a78182 | ||
|
|
22f46517f4 | ||
|
|
45596e1f94 | ||
|
|
6de0dbe854 | ||
|
|
011827fa29 | ||
|
|
fc6d244071 | ||
|
|
cd3da886d6 | ||
|
|
c013c55d92 | ||
|
|
cd3dd7db0d | ||
|
|
1fdcce9429 | ||
|
|
181e40926d | ||
|
|
c62ede5878 | ||
|
|
a2ad5f1a9a | ||
|
|
ff74a5356f | ||
|
|
92dc30dace | ||
|
|
3af577b210 | ||
|
|
d0464330f7 | ||
|
|
dd3ef4a80f | ||
|
|
0ced891944 | ||
|
|
10a5452df9 | ||
|
|
cb97969bbc | ||
|
|
71e742e238 | ||
|
|
fadd20fb8e | ||
|
|
01b9ca78e4 | ||
|
|
2baf825f34 | ||
|
|
1fa8048509 | ||
|
|
a000ad75f6 | ||
|
|
aefb2339bb | ||
|
|
a4f8671f86 | ||
|
|
73530ba54f | ||
|
|
685eb9927d | ||
|
|
ee57302fc3 | ||
|
|
c1fb9cdb93 | ||
|
|
aa6d441552 | ||
|
|
25d8d4c2e9 | ||
|
|
427ea6da5c | ||
|
|
d9f4266630 | ||
|
|
96f6e9e683 | ||
|
|
f10248e3f5 | ||
|
|
6a21f5fde1 | ||
|
|
ff20dd509a | ||
|
|
78a59b5b78 | ||
|
|
46bfbbbc87 | ||
|
|
a6d73d0773 | ||
|
|
6578e8bef8 | ||
|
|
0596d25e07 | ||
|
|
86e8ce9139 | ||
|
|
5aa2957da4 | ||
|
|
82f0cb2c8c | ||
|
|
fa48145cbc | ||
|
|
6d1edc330d | ||
|
|
97c0d3f6be | ||
|
|
a79a25ad63 | ||
|
|
6a8ceef404 | ||
|
|
3539670d93 | ||
|
|
c54bc32ef6 | ||
|
|
fee293e289 | ||
|
|
747eef9ccc | ||
|
|
7d2df399ed | ||
|
|
68fad5cdcc | ||
|
|
b4d656c203 | ||
|
|
3136d89d52 | ||
|
|
27e829b955 | ||
|
|
e03e870d5b | ||
|
|
9465ff450b | ||
|
|
92906a9575 | ||
|
|
77f206abe4 | ||
|
|
44a3f61580 | ||
|
|
0a2afed08b | ||
|
|
9b3b961105 | ||
|
|
9b1828e1aa | ||
|
|
5101873f49 | ||
|
|
c612f18114 | ||
|
|
7e400d876f | ||
|
|
677dddcfc9 | ||
|
|
0792b9175e | ||
|
|
e4829f80af | ||
|
|
bb760f3eb4 | ||
|
|
388c65287b | ||
|
|
12cd41e05c | ||
|
|
7765c03949 | ||
|
|
3daa80c57f | ||
|
|
5dbbef4ebd | ||
|
|
db33b3f7b5 | ||
|
|
8ffcf2a6be | ||
|
|
2e7ae6a07e | ||
|
|
fea1711f0c | ||
|
|
2a3546db97 | ||
|
|
285c266612 | ||
|
|
426ad54c53 | ||
|
|
fc75f7919f | ||
|
|
6c6b1aaff6 | ||
|
|
c319d653ac | ||
|
|
d887e474e7 | ||
|
|
da7b52d6ba | ||
|
|
b5aa308593 | ||
|
|
0b7ceb3bb6 | ||
|
|
3a70cefda2 | ||
|
|
4b609251e1 | ||
|
|
0839eac0f7 | ||
|
|
5f2a7feeee | ||
|
|
982535eb92 | ||
|
|
0c2b8edc8d | ||
|
|
f78f4ca25f | ||
|
|
d6b3e6c07d | ||
|
|
071ff8e74a | ||
|
|
1ea8aafca1 | ||
|
|
533dd221f8 | ||
|
|
2b325c6683 | ||
|
|
3845b1b3e6 | ||
|
|
cea7890a67 | ||
|
|
c38fe8025d | ||
|
|
f1de95349c | ||
|
|
2950775fa7 | ||
|
|
cb293fd7ac | ||
|
|
43b3fab6be | ||
|
|
d4b0dbce49 | ||
|
|
137b810669 | ||
|
|
c172657324 | ||
|
|
97c966b04f | ||
|
|
7178fc6253 | ||
|
|
4adb2eabf5 | ||
|
|
9f2c815e13 | ||
|
|
1435557d1d | ||
|
|
96abf687f6 | ||
|
|
636d9a7209 | ||
|
|
3b36eb0223 | ||
|
|
388c97bff0 | ||
|
|
b1cb018695 | ||
|
|
df78dd7953 | ||
|
|
0dc344a22e | ||
|
|
350d7f6f14 | ||
|
|
11059ee2d4 | ||
|
|
c90d3f3bb9 | ||
|
|
7f6d439fd1 | ||
|
|
783a78f069 | ||
|
|
0ff031950d | ||
|
|
d7e8f3d756 | ||
|
|
4668ea449b | ||
|
|
30d318d021 | ||
|
|
de96f97e5f | ||
|
|
57c0a2dfb1 | ||
|
|
cd4e464bde | ||
|
|
49e48c3eb7 | ||
|
|
edd3b3bce9 | ||
|
|
f8bfb66108 | ||
|
|
3b6a76cbf3 | ||
|
|
e0b60e4320 | ||
|
|
2159319035 | ||
|
|
b170fc232e | ||
|
|
594da60f2f | ||
|
|
6a432f6518 | ||
|
|
eb8eacfec6 | ||
|
|
c8d04d42e2 | ||
|
|
d39c9de81e | ||
|
|
a27d39b9ff | ||
|
|
6b385614f0 | ||
|
|
3ae7250ef7 | ||
|
|
a42d0ce1d2 | ||
|
|
d9131f7563 | ||
|
|
bdce958f29 | ||
|
|
3c86f1e979 | ||
|
|
894b8a29b9 | ||
|
|
8436a44973 | ||
|
|
f9f9ec3688 | ||
|
|
5a98d7a1f6 | ||
|
|
f9bc96e497 | ||
|
|
56350ff5dc | ||
|
|
6c1139340c | ||
|
|
641b1a7e6f | ||
|
|
674a3f462f | ||
|
|
2283186d3a | ||
|
|
340af1fe50 | ||
|
|
9378656d78 | ||
|
|
5f0413e222 | ||
|
|
c3e47515b1 | ||
|
|
5dcef6fa0d | ||
|
|
31ac02cd93 | ||
|
|
ab16976084 | ||
|
|
8e2b7845e1 | ||
|
|
3973bce342 | ||
|
|
f63847a504 | ||
|
|
07e3529948 | ||
|
|
03e1c60694 | ||
|
|
d766ed71fc | ||
|
|
ae68ef142a | ||
|
|
20f55768c4 | ||
|
|
c4fad4456e | ||
|
|
78f5ec44ad | ||
|
|
e14ba86942 | ||
|
|
d4e7720f6b | ||
|
|
a3f0e7e1cb | ||
|
|
30a696c476 | ||
|
|
66d6c64e16 | ||
|
|
d15be9b57c | ||
|
|
e5da902fd0 | ||
|
|
fc558094c2 | ||
|
|
ad9312e989 | ||
|
|
8e1a70b008 | ||
|
|
17f88cd5ad | ||
|
|
298f1919fa | ||
|
|
4d20cc11d4 | ||
|
|
14f249a2f0 | ||
|
|
b9746a6c2c | ||
|
|
94f298a6f4 | ||
|
|
8d3a8178da | ||
|
|
cad4212fe8 | ||
|
|
cff28dfaa9 | ||
|
|
70d7509fcc | ||
|
|
cf83af7a27 | ||
|
|
5c5a405c0f | ||
|
|
0208e4b232 | ||
|
|
e940754795 | ||
|
|
dc9fa1a735 | ||
|
|
08591fbf6d | ||
|
|
74db71bb5d | ||
|
|
60dbe798a5 | ||
|
|
0e676605fe | ||
|
|
3f781016f6 | ||
|
|
17cd2f6b02 | ||
|
|
99102a1b34 | ||
|
|
8d72e7d9e8 | ||
|
|
0b6b6f97ad | ||
|
|
fb2f6382b1 | ||
|
|
1ddea87c35 | ||
|
|
ea02323095 | ||
|
|
49733091c7 | ||
|
|
cf833fd6e2 | ||
|
|
ba5cf07ab8 | ||
|
|
d15321a373 | ||
|
|
de597a5eb4 | ||
|
|
e5f5cbdf5c | ||
|
|
7d4342bbff | ||
|
|
7f8a1d8d20 | ||
|
|
65353ac1e1 | ||
|
|
7f9a31ca4a | ||
|
|
592eb2886c | ||
|
|
c220dd8987 | ||
|
|
a263beb0d5 | ||
|
|
46b7c510eb | ||
|
|
f405e472ea | ||
|
|
7bdfd3ef5f | ||
|
|
778ee2c679 | ||
|
|
e70339ff3e | ||
|
|
88c57a9750 | ||
|
|
137252128b | ||
|
|
d4297b1345 | ||
|
|
6059bc7b47 | ||
|
|
c3ff3eb51f | ||
|
|
0b7751c413 | ||
|
|
d7f1c30624 | ||
|
|
3f4d7dbeea | ||
|
|
19b6ae2907 | ||
|
|
769f96ff9f | ||
|
|
fdaf75faa4 | ||
|
|
1380bb7ae6 | ||
|
|
9483c8cc29 | ||
|
|
2ef8a8cf5a | ||
|
|
d296ec1932 | ||
|
|
444ad3dae1 | ||
|
|
8cdcc71378 | ||
|
|
e8bc06cfd3 | ||
|
|
67a0a024e9 | ||
|
|
bd2c46c267 | ||
|
|
5acb27f350 | ||
|
|
7271b12d0f | ||
|
|
4a79467a33 | ||
|
|
5501bb87a3 | ||
|
|
561610e296 | ||
|
|
b76609ef18 | ||
|
|
070b78501b | ||
|
|
50df4f4ab6 | ||
|
|
9bbf430125 | ||
|
|
384a90958a | ||
|
|
0e4a25b029 | ||
|
|
4a44e171fd | ||
|
|
9bc57a6f59 | ||
|
|
4341ed7ab4 | ||
|
|
97ce72c542 | ||
|
|
a2c78a57a7 | ||
|
|
044a713dc9 | ||
|
|
b8479c5fe2 | ||
|
|
4e5d056824 | ||
|
|
118278b372 | ||
|
|
8e8c255f3f | ||
|
|
1575bee401 | ||
|
|
249bbfc883 | ||
|
|
3993ae410f | ||
|
|
edf040e3d2 | ||
|
|
66fd077ee7 | ||
|
|
b93462ebb6 | ||
|
|
aae60d0cdc | ||
|
|
d4da00e607 | ||
|
|
0c539ff00b | ||
|
|
5983cbf26c | ||
|
|
c513d6e3af | ||
|
|
9d57c0e631 | ||
|
|
a1923a8966 | ||
|
|
d988e18731 | ||
|
|
51008da2dd | ||
|
|
6ccc1f5672 | ||
|
|
4a556f84e0 | ||
|
|
2f21a2220d | ||
|
|
91a420b13e | ||
|
|
c27da3581b | ||
|
|
961dfbce93 | ||
|
|
df39c825ae | ||
|
|
3f6ee1b7a4 | ||
|
|
908e504a6f | ||
|
|
f2fa41afc5 | ||
|
|
440ff40ad5 | ||
|
|
5c15458e15 | ||
|
|
be5b474f1e | ||
|
|
cee178c2b6 | ||
|
|
27657f8b7a | ||
|
|
e0cde3815a | ||
|
|
09d0421de4 | ||
|
|
47b94d563c | ||
|
|
0b5d20c9f0 | ||
|
|
80e7e1293a | ||
|
|
3a82b0cbc1 | ||
|
|
a27cbc13b6 | ||
|
|
a8f962eb3f | ||
|
|
7f40d23f19 | ||
|
|
918354cd9d | ||
|
|
eef9278ee6 | ||
|
|
2c32e2e5c1 | ||
|
|
6f05654db5 | ||
|
|
1d31b6902f | ||
|
|
5a7d615e64 | ||
|
|
1dbf9e4ed4 | ||
|
|
5dcc6ee203 | ||
|
|
84a4e6ae3f | ||
|
|
f283bfd68f | ||
|
|
6e5ff7b79c | ||
|
|
7c3800d03f | ||
|
|
941db90518 | ||
|
|
0d9ecf0f90 | ||
|
|
9c77023a11 | ||
|
|
b55378c63c | ||
|
|
946c2a49ab | ||
|
|
b823c31ec6 | ||
|
|
ec6361e5cb | ||
|
|
0c26d28278 | ||
|
|
c5172d4c5a | ||
|
|
89de04775e | ||
|
|
b4c3c940b5 | ||
|
|
aee2aad959 | ||
|
|
5ca48a8a5f | ||
|
|
1806aa187b | ||
|
|
7824cb7a1a | ||
|
|
9807a896f4 | ||
|
|
19866f057d | ||
|
|
ec4eae3c9c | ||
|
|
bea0cba038 | ||
|
|
48ee75af9c | ||
|
|
929c593d2f | ||
|
|
221f32eca7 | ||
|
|
c3acc15e8b | ||
|
|
1b653278fc | ||
|
|
cc9062ee46 | ||
|
|
91c0feb0ad | ||
|
|
ae60292ac8 | ||
|
|
a6ca17b19d | ||
|
|
6a4a5ece74 | ||
|
|
9b81860307 | ||
|
|
5f4a3928d2 | ||
|
|
b703884763 | ||
|
|
32da98ab8f | ||
|
|
bd5a85bf70 | ||
|
|
d045f24014 | ||
|
|
2aad3f89c3 | ||
|
|
dd54d19f00 | ||
|
|
0ed6591d8c | ||
|
|
712e090134 | ||
|
|
8fc2a1d1cf | ||
|
|
cc15c1593e | ||
|
|
9997d3abda | ||
|
|
031471e785 | ||
|
|
2e860c6791 | ||
|
|
d071a9e17d | ||
|
|
ed53d33321 | ||
|
|
382bc6d978 | ||
|
|
dab42e258c | ||
|
|
81556410bb | ||
|
|
1f2dfd473c | ||
|
|
8f0f51be2c | ||
|
|
7179e250ed | ||
|
|
5bec091fd6 | ||
|
|
2c5896cb0c | ||
|
|
93ff252dc0 | ||
|
|
ac52224455 | ||
|
|
4087cad23d | ||
|
|
e936b1ff8f | ||
|
|
b7f9c5e221 | ||
|
|
fc5467150e | ||
|
|
4dcab357a0 | ||
|
|
695e464255 | ||
|
|
9999b60c3b | ||
|
|
e7df53e260 | ||
|
|
844590a571 | ||
|
|
9622beaa0d | ||
|
|
007e2553a8 | ||
|
|
15ad4e3f5e | ||
|
|
be5094fcb4 | ||
|
|
a20a861680 | ||
|
|
396d0a4bc0 | ||
|
|
ca9314e077 | ||
|
|
4b848798e7 | ||
|
|
083bcbc77d | ||
|
|
e8cdc9ae62 | ||
|
|
8abfa759a4 | ||
|
|
f6a324b633 | ||
|
|
f083be9391 | ||
|
|
091e2fb751 | ||
|
|
d8539daf1f | ||
|
|
7ec059f5fa | ||
|
|
4f2ecdefd2 | ||
|
|
e8891a1988 | ||
|
|
37d2607f34 | ||
|
|
0e7b10d3d9 | ||
|
|
1f85888638 | ||
|
|
c1f9a129fa | ||
|
|
7ccc5ba398 | ||
|
|
5e1a6ae334 | ||
|
|
3f6cf638f9 | ||
|
|
46e062a828 | ||
|
|
cc3a0b5d6c | ||
|
|
775479ab7b | ||
|
|
6b9e0e6d63 | ||
|
|
83a5c87f5e | ||
|
|
84fde74331 | ||
|
|
a517e29b39 | ||
|
|
ccceba7565 | ||
|
|
5fc7a03669 | ||
|
|
8864ad1b50 | ||
|
|
f2989885fb | ||
|
|
19c66e5c76 | ||
|
|
8a6690a57c | ||
|
|
2cc60f253a | ||
|
|
cb69872dd3 | ||
|
|
ba66d7c9a6 | ||
|
|
9fe727c9f8 | ||
|
|
58c656224f | ||
|
|
c51253f5f6 | ||
|
|
6c1d1588fc | ||
|
|
95d6183a6c | ||
|
|
f4c9facdaf | ||
|
|
a274e6f165 | ||
|
|
154e3e6f64 | ||
|
|
2c3ac972e5 | ||
|
|
2e2e072b0b | ||
|
|
9d8dd2bf66 | ||
|
|
9047f6db30 | ||
|
|
5ab345ee63 | ||
|
|
d8a83acd3a | ||
|
|
1f58e5756b | ||
|
|
744acb8f07 | ||
|
|
ae7228d821 | ||
|
|
99d8b3a7bf | ||
|
|
3fbe65bbcf | ||
|
|
f5d879d8e7 | ||
|
|
cbc5a4f8e6 | ||
|
|
37ac7d8ed5 | ||
|
|
bee3fa339d | ||
|
|
c171fe2b96 | ||
|
|
1fa8032fdb | ||
|
|
5c0676bcc2 | ||
|
|
cefd9a027c | ||
|
|
1bce156de1 | ||
|
|
cd4f63f2fd | ||
|
|
3c7140cbf3 | ||
|
|
b71ba63b5a | ||
|
|
d540e2c0d3 | ||
|
|
d79fafc5f5 | ||
|
|
9e93fa2092 | ||
|
|
392e9b4882 | ||
|
|
231e5ec94a | ||
|
|
e5bb6f9693 | ||
|
|
da7dee44c6 | ||
|
|
83144f4fe3 | ||
|
|
c451f52ea3 | ||
|
|
8a2c78f2e1 | ||
|
|
bcc78bde9b | ||
|
|
054bb6fe0a | ||
|
|
4f4aa6d92e | ||
|
|
eac51ac6f5 | ||
|
|
9f349a7c0a | ||
|
|
918afa5b15 | ||
|
|
eb1113f95c | ||
|
|
4f4ba7b462 | ||
|
|
2298be0e6b | ||
|
|
63494dfca7 | ||
|
|
36a1d39454 | ||
|
|
a6f6d5c400 | ||
|
|
e85f221aca | ||
|
|
d4797e37dc | ||
|
|
3e7923d072 | ||
|
|
a85d69ce3d | ||
|
|
96db006c99 | ||
|
|
8ca57d03d8 | ||
|
|
6c404ce5f8 | ||
|
|
584e07182b | ||
|
|
f787e9acf6 | ||
|
|
5a24b89e54 | ||
|
|
9b482e2a4f | ||
|
|
df4dbe2d57 | ||
|
|
713bd11177 | ||
|
|
182571df4b | ||
|
|
29bfe492b6 | ||
|
|
3fb4e3050c | ||
|
|
39c7ec3cd9 | ||
|
|
26bfbdec7f | ||
|
|
7a3eaa8da9 | ||
|
|
599db7296f | ||
|
|
042aab4295 | ||
|
|
24f298283f | ||
|
|
68dac6349d | ||
|
|
b675fc19e8 | ||
|
|
659019cfd6 | ||
|
|
dcd61e1f82 | ||
|
|
f5c99b1488 | ||
|
|
810be3e1d4 | ||
|
|
60d754d1df | ||
|
|
bd07c86db9 | ||
|
|
bcbf8b6bd8 | ||
|
|
356661459b | ||
|
|
deb917825e | ||
|
|
15415c6d85 | ||
|
|
76b0380b5f | ||
|
|
2d58754789 | ||
|
|
9cdf1f599c | ||
|
|
268be97ba0 | ||
|
|
a9014673a0 | ||
|
|
d36c43a10f | ||
|
|
54a5c4e482 | ||
|
|
5e09a244e3 | ||
|
|
88648dca1a | ||
|
|
8840df2b00 | ||
|
|
af159acbdf | ||
|
|
471719bbbe | ||
|
|
b126f2ffd5 | ||
|
|
9938f12ef0 | ||
|
|
982c266073 | ||
|
|
5c37391883 | ||
|
|
ddeafc6833 | ||
|
|
41b2d5d013 | ||
|
|
29d6f48901 | ||
|
|
d5c9f4e47f | ||
|
|
24d73387d8 | ||
|
|
e0d3927265 | ||
|
|
e5f7c2a9b7 | ||
|
|
b0760710d5 | ||
|
|
764accc921 | ||
|
|
6a01fce9c1 | ||
|
|
9c732ac3b1 | ||
|
|
b70891c661 | ||
|
|
4dbf851741 | ||
|
|
6c927a9fd4 | ||
|
|
096f001634 | ||
|
|
4837e578b2 | ||
|
|
1e547ef912 | ||
|
|
f6b8970bd1 | ||
|
|
29325a7214 | ||
|
|
8ecf72838d | ||
|
|
c3ab8a6aa8 | ||
|
|
1931aa3e70 | ||
|
|
d3d8055055 | ||
|
|
476b0a0403 | ||
|
|
f66584713c | ||
|
|
33624fc2fa | ||
|
|
41c3e73a3c | ||
|
|
97553a7de2 | ||
|
|
12ba15bfa9 | ||
|
|
09d1e190e7 | ||
|
|
8eb5d08499 | ||
|
|
9be6acde7d | ||
|
|
5f83bb0069 | ||
|
|
b138882abc | ||
|
|
0cd7cdb52e | ||
|
|
1d8b7e2bcf | ||
|
|
6461f4758d | ||
|
|
3189ab6863 | ||
|
|
3f9a674d4b | ||
|
|
587f59b25b | ||
|
|
4952eada87 | ||
|
|
581029ebaa | ||
|
|
42d68780de | ||
|
|
28032a2f80 | ||
|
|
e381e021e9 | ||
|
|
641af64f93 | ||
|
|
a7b83c8b5b | ||
|
|
4cc41e0188 | ||
|
|
442fc02429 | ||
|
|
9a4d075074 | ||
|
|
17ff8196cb | ||
|
|
68f993998a | ||
|
|
7da6120b39 | ||
|
|
6cd40965c4 | ||
|
|
408a1d6dbb | ||
|
|
0b0abfbe8f | ||
|
|
cc96dcf0ed | ||
|
|
2604fd9fde | ||
|
|
140670d00e | ||
|
|
70233fae5d | ||
|
|
6f457a6c4c | ||
|
|
5c319f5356 | ||
|
|
991a04f090 | ||
|
|
c39fa75113 | ||
|
|
f7863e17ce | ||
|
|
7c526390ed | ||
|
|
2cff20f87a | ||
|
|
90ec757802 | ||
|
|
4b85dfcefe | ||
|
|
21deefdc41 | ||
|
|
857d74bbfe | ||
|
|
fd7a635777 | ||
|
|
af9110e964 | ||
|
|
a61209206b | ||
|
|
e05cc62e5f | ||
|
|
4d4f921a4e | ||
|
|
98db8f395b | ||
|
|
f465a956a3 | ||
|
|
9edb02d7ef | ||
|
|
6c4cf58a31 | ||
|
|
08993c0d29 | ||
|
|
4f8a4b0f22 | ||
|
|
a743f3c9b5 | ||
|
|
217fe40d99 | ||
|
|
b76bf50b93 | ||
|
|
571ba87e13 | ||
|
|
f27b6e2b44 | ||
|
|
981475a624 | ||
|
|
27ac61a4fb | ||
|
|
675ffc2757 | ||
|
|
44b21f10f1 | ||
|
|
c6d49e8b1f | ||
|
|
e6a512aa86 | ||
|
|
c3a6a6fb22 | ||
|
|
b9dc3460ba | ||
|
|
63581ec980 | ||
|
|
08b1feeed7 | ||
|
|
f5cfdcf32d | ||
|
|
e78fb428f0 | ||
|
|
31e270e32c | ||
|
|
b5832768dc | ||
|
|
4ce64b69cb | ||
|
|
5a9173f766 | ||
|
|
0bb7ed44f6 | ||
|
|
332bc9da5b | ||
|
|
08def3da95 | ||
|
|
daf899f9c4 | ||
|
|
13fb2d1f49 | ||
|
|
95dde802ea | ||
|
|
fca119773b | ||
|
|
0193267a53 | ||
|
|
b4cf78a95d | ||
|
|
73386826d6 | ||
|
|
9f448fecb7 | ||
|
|
bcd1483a14 | ||
|
|
e206890e25 | ||
|
|
0a7048f650 | ||
|
|
e8ecf5e155 | ||
|
|
33e8604b57 | ||
|
|
cec7399366 | ||
|
|
bdae81e429 | ||
|
|
67c32f3d6c | ||
|
|
94d64b8a78 | ||
|
|
fa3c0c81b3 | ||
|
|
66547b99c1 | ||
|
|
328e58be4c | ||
|
|
18f89ed5ed | ||
|
|
5701c79fab | ||
|
|
2da9f913f3 | ||
|
|
6b10b59abe | ||
|
|
918f77bce0 | ||
|
|
f170697ebe | ||
|
|
556c6a1d84 | ||
|
|
aca2a2fa13 | ||
|
|
ff6398f7d8 | ||
|
|
cf996472b9 | ||
|
|
156d14c349 | ||
|
|
86f705bf48 | ||
|
|
1fd9631f2d | ||
|
|
2227a2357f | ||
|
|
58e7ab157d | ||
|
|
8d16fa6a49 | ||
|
|
55e810efa3 | ||
|
|
2755316021 | ||
|
|
6525f18610 | ||
|
|
2ad13ac7eb | ||
|
|
693a3eaff5 | ||
|
|
ffca792d5b | ||
|
|
86a92bb6b5 | ||
|
|
171a4e6d80 | ||
|
|
e3a75a8adf | ||
|
|
ee7503ce13 | ||
|
|
8500bac3ca | ||
|
|
310719eb4c | ||
|
|
e8e24822ec | ||
|
|
c57a7afb87 | ||
|
|
84d028898c | ||
|
|
ed0174fbc6 | ||
|
|
9e582563eb | ||
|
|
faa88f72bf | ||
|
|
0d69a31df0 | ||
|
|
daa5a88eb2 | ||
|
|
5b84e117b2 | ||
|
|
eb257d2d28 | ||
|
|
5810cee6c9 | ||
|
|
eef88d1f83 | ||
|
|
78f6850fc0 | ||
|
|
bd8890be11 | ||
|
|
adf1a977ea | ||
|
|
e1509bcb45 | ||
|
|
edcaf8287d | ||
|
|
39bd30f2a0 | ||
|
|
102b47190f | ||
|
|
269fe2e3bb | ||
|
|
b32aa1c77f | ||
|
|
6656544ed5 | ||
|
|
4c75b93410 | ||
|
|
5be0de967d | ||
|
|
f8e27b837b | ||
|
|
47414be1e6 | ||
|
|
74cef38bcf | ||
|
|
bb876b8d4e | ||
|
|
ba747373db | ||
|
|
95661c8b21 | ||
|
|
e5d9ca013e | ||
|
|
4166c756ce | ||
|
|
4f0dfbd34d | ||
|
|
b70ac88684 | ||
|
|
24609da6ab | ||
|
|
524647b1f1 | ||
|
|
cf1af94f53 | ||
|
|
2a9fdc6314 | ||
|
|
46c632e7cc | ||
|
|
653f63ae71 | ||
|
|
8a9e2f57a4 | ||
|
|
31949ed2f2 | ||
|
|
3657285b1b | ||
|
|
e4b5975305 | ||
|
|
b59825edc0 | ||
|
|
25788f6869 | ||
|
|
0ccb304b8b | ||
|
|
ca5a4ee59d | ||
|
|
4fdefe58c7 | ||
|
|
9870f5a96f | ||
|
|
c296ae8cfe | ||
|
|
17493f4ae0 | ||
|
|
2503dca813 | ||
|
|
cb61ef9bb1 | ||
|
|
1831ed620f | ||
|
|
c385e76356 | ||
|
|
ff1972fbb3 | ||
|
|
c4b3405bfa | ||
|
|
ab2548c0cd | ||
|
|
dc2a3363b0 | ||
|
|
d7a5fe2805 | ||
|
|
4e49689d46 | ||
|
|
ca8441a32f | ||
|
|
44284d671c | ||
|
|
e89de1d5b7 | ||
|
|
6db63349f8 | ||
|
|
7f6f892533 | ||
|
|
d1bbd0cf80 | ||
|
|
bd73b6b2af | ||
|
|
0d40a7d865 | ||
|
|
c2f6b80246 | ||
|
|
80f5f8210a | ||
|
|
b7383cc0e5 | ||
|
|
2172e4d292 | ||
|
|
ab0bfa709a | ||
|
|
6af659b1da | ||
|
|
db664afc49 | ||
|
|
b99a53e64e | ||
|
|
5f4ce6fda3 | ||
|
|
93e95ce53f | ||
|
|
2997f0a1f8 | ||
|
|
40b262bcc2 | ||
|
|
a26f050cbb | ||
|
|
94b5b2a467 | ||
|
|
b4519ea61f | ||
|
|
7f7ce291b5 | ||
|
|
aeb53563ff | ||
|
|
e8d2e2330e | ||
|
|
4c6b9ce7c9 | ||
|
|
87a2221d72 | ||
|
|
76aa6bdf05 | ||
|
|
416d29fb83 | ||
|
|
0c1994d682 | ||
|
|
19c00241c6 | ||
|
|
633bbb4e85 | ||
|
|
a221ab2fb6 | ||
|
|
0279a27f66 | ||
|
|
54aef4959c | ||
|
|
4017609b91 | ||
|
|
cb0bffedd5 | ||
|
|
1fd2a91ccd | ||
|
|
c323a760a5 | ||
|
|
9d1fcba415 | ||
|
|
075e0405f9 | ||
|
|
bf6066d834 | ||
|
|
5635f65ee9 | ||
|
|
6317cf8ef9 | ||
|
|
9e1daf06f7 | ||
|
|
e1a718b512 | ||
|
|
cbce89162b | ||
|
|
b46b20210d | ||
|
|
8e89157a83 | ||
|
|
ca21996a97 | ||
|
|
62aa064e56 | ||
|
|
7c975f0d00 | ||
|
|
8107884c8d | ||
|
|
a2f49ef7c1 | ||
|
|
e2e47fd606 | ||
|
|
c098edc6b2 | ||
|
|
bc1d9748ce | ||
|
|
7b8e25f525 | ||
|
|
db52f5606f | ||
|
|
de39c5ed21 | ||
|
|
d014dc94fd | ||
|
|
39e804d0f8 | ||
|
|
154e8f6e78 | ||
|
|
2d31b82e60 | ||
|
|
8f934747f3 | ||
|
|
4352341a00 | ||
|
|
d7e0ec52ff | ||
|
|
1072b74c0e | ||
|
|
46dc8c6641 | ||
|
|
a8bc6ab5b1 | ||
|
|
bd91bd4a84 | ||
|
|
8756a6b8c3 | ||
|
|
2e0cebb571 | ||
|
|
c3a8184431 | ||
|
|
ffa39d74b3 | ||
|
|
f9d3966ea2 | ||
|
|
7cee4e42a7 | ||
|
|
071c7c7c7e | ||
|
|
818045f678 | ||
|
|
7edefbefff | ||
|
|
29efab70b7 | ||
|
|
ac6adc392a | ||
|
|
a2ef5d56ee | ||
|
|
13f3560e55 | ||
|
|
c4bd60e00f | ||
|
|
54eda9163c | ||
|
|
582f384fff | ||
|
|
a43211e650 | ||
|
|
6cb0581b0d | ||
|
|
845d77916e | ||
|
|
f18431a999 | ||
|
|
5060bf2f62 | ||
|
|
7854d913b2 | ||
|
|
890a3ce32a | ||
|
|
fb4b3f3350 | ||
|
|
d166b08b6a | ||
|
|
5266e9e682 | ||
|
|
d0265e21b0 | ||
|
|
3126e8e49a | ||
|
|
9e3412d776 | ||
|
|
4a09cc57be | ||
|
|
5ab36e0433 | ||
|
|
d2bf3629bf | ||
|
|
d9b217d908 | ||
|
|
2847f1b5ac | ||
|
|
bc30850f3a | ||
|
|
7668dc68a0 | ||
|
|
ea449f5a0a | ||
|
|
5a1ed99ca1 | ||
|
|
3a2707ac02 | ||
|
|
ce5b1103ed | ||
|
|
fd91b83d86 | ||
|
|
a0a54348e8 | ||
|
|
43b3e242b0 | ||
|
|
4e8dcb7a1a | ||
|
|
3cb13d6288 | ||
|
|
4f01c0f2d3 | ||
|
|
87eb018380 | ||
|
|
5003e5d763 | ||
|
|
e92af52fb8 | ||
|
|
5f0fe3c8a9 | ||
|
|
339dddd018 | ||
|
|
1b359b55cb | ||
|
|
d0435575c1 | ||
|
|
58f3072b91 | ||
|
|
9e7b470189 | ||
|
|
42356ec866 | ||
|
|
1748848b7b | ||
|
|
5772965f09 | ||
|
|
e046e60e1c | ||
|
|
9a1420280e | ||
|
|
f9c61f1b6c | ||
|
|
a8cc5caf96 | ||
|
|
930ff559e4 | ||
|
|
473f4cc1c3 | ||
|
|
78d2b1b650 | ||
|
|
39e10d894c | ||
|
|
e16faa6370 | ||
|
|
83a86abce2 | ||
|
|
0c56d4a581 | ||
|
|
97a7f51721 | ||
|
|
710dc6b487 | ||
|
|
2ef3b49a79 | ||
|
|
3f79467f7b | ||
|
|
2c2ec8f0bc | ||
|
|
79e35bd0d3 | ||
|
|
137202b77c | ||
|
|
03e22c257b | ||
|
|
ae6d4fbc78 | ||
|
|
cd1bc1595a | ||
|
|
0583101c1c | ||
|
|
f866b49255 | ||
|
|
b7c6c63005 | ||
|
|
95e9f5323b | ||
|
|
6b0ca88177 | ||
|
|
7ad32dcad2 | ||
|
|
81991e072b | ||
|
|
cec345cb5c | ||
|
|
608cbe3f5c | ||
|
|
7905a46ca4 | ||
|
|
38343917f8 | ||
|
|
9f088d1bf5 | ||
|
|
fd8d1c12d4 | ||
|
|
d623bd429b | ||
|
|
499e4d4fde | ||
|
|
e961dd1dec | ||
|
|
7e00526999 | ||
|
|
3a9dda9177 | ||
|
|
bd8ae5d896 | ||
|
|
87e96e1be2 | ||
|
|
0bc60378d3 | ||
|
|
9cc852cf7f | ||
|
|
0428ce73a9 | ||
|
|
d0d2955992 | ||
|
|
d868d5d584 | ||
|
|
ab775726b7 | ||
|
|
650902dc29 | ||
|
|
7b5d4935b4 | ||
|
|
ecbff2aa44 | ||
|
|
0ce6ec634d | ||
|
|
d09999736c | ||
|
|
a405f14ea2 | ||
|
|
9d3739244f | ||
|
|
534528b85a | ||
|
|
114320ee69 | ||
|
|
6161aa73af | ||
|
|
1ab20f43c8 | ||
|
|
9328c17ded | ||
|
|
c1c8e55e8e | ||
|
|
504a42fe61 | ||
|
|
29c8ddfb88 | ||
|
|
95079dc7d4 | ||
|
|
2a1514272f | ||
|
|
59ce9cf41c | ||
|
|
e6abea7bc5 | ||
|
|
c335f92345 | ||
|
|
c1afe35704 | ||
|
|
e4813f800a |
2
.github/workflows/python-checks.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff
|
||||
run: pip install ruff==0.6.0
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
|
||||
@@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
# Invoke in Docker
|
||||
|
||||
- Ensure that Docker can use the GPU on your system
|
||||
- This documentation assumes Linux, but should work similarly under Windows with WSL2
|
||||
First things first:
|
||||
|
||||
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
|
||||
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
|
||||
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
|
||||
|
||||
## Quickstart :lightning:
|
||||
## Quickstart
|
||||
|
||||
No `docker compose`, no persistence, just a simple one-liner using the official images:
|
||||
No `docker compose`, no persistence, single command, using the official images:
|
||||
|
||||
**CUDA:**
|
||||
**CUDA (NVIDIA GPU):**
|
||||
|
||||
```bash
|
||||
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
|
||||
```
|
||||
|
||||
**ROCm:**
|
||||
**ROCm (AMD GPU):**
|
||||
|
||||
```bash
|
||||
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
|
||||
@@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
|
||||
|
||||
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
|
||||
|
||||
> [!TIP]
|
||||
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
|
||||
### Data persistence
|
||||
|
||||
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
|
||||
|
||||
```bash
|
||||
docker run --volume /some/local/path:/invokeai {...etc...}
|
||||
```
|
||||
|
||||
`/some/local/path/invokeai` will contain all your data.
|
||||
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
|
||||
|
||||
## Customize the container
|
||||
|
||||
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
|
||||
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
@@ -38,11 +48,14 @@ cp .env.sample .env
|
||||
|
||||
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
|
||||
|
||||
>[!TIP]
|
||||
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
|
||||
|
||||
## Docker setup in detail
|
||||
|
||||
#### Linux
|
||||
|
||||
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
|
||||
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
|
||||
3. Ensure docker daemon is able to access the GPU.
|
||||
@@ -98,25 +111,7 @@ GPU_DRIVER=cuda
|
||||
|
||||
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
||||
|
||||
## Even More Customizing!
|
||||
---
|
||||
|
||||
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||
|
||||
### Reconfigure the runtime directory
|
||||
|
||||
Can be used to download additional models from the supported model list
|
||||
|
||||
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
|
||||
|
||||
```yaml
|
||||
command:
|
||||
- invokeai-configure
|
||||
- --yes
|
||||
```
|
||||
|
||||
Or install models:
|
||||
|
||||
```yaml
|
||||
command:
|
||||
- invokeai-model-install
|
||||
```
|
||||
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
||||
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
set -eu
|
||||
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
scriptdir=$(dirname $(readlink -f "$0"))
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
@@ -31,6 +32,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
@@ -63,7 +66,12 @@ class ApiDependencies:
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||
def initialize(
|
||||
config: InvokeAIAppConfig,
|
||||
event_handler_id: int,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
logger: Logger = logger,
|
||||
) -> None:
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
|
||||
@@ -74,6 +82,7 @@ class ApiDependencies:
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
style_presets_folder = config.style_presets_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
@@ -84,7 +93,7 @@ class ApiDependencies:
|
||||
board_images = BoardImagesService()
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
events = FastAPIEventService(event_handler_id, loop=loop)
|
||||
bulk_download = BulkDownloadService()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
@@ -109,6 +118,8 @@ class ApiDependencies:
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -134,6 +145,8 @@ class ApiDependencies:
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -218,9 +218,8 @@ async def get_image_workflow(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
@images_router.get(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -231,6 +230,18 @@ async def get_image_workflow(
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
@images_router.head(
|
||||
"/i/{image_name}/full",
|
||||
operation_id="get_image_full_head",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> Response:
|
||||
@@ -242,6 +253,7 @@ async def get_image_full(
|
||||
content = f.read()
|
||||
response = Response(content, media_type="image/png")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
@@ -6,7 +6,7 @@ import pathlib
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
@@ -430,13 +430,11 @@ async def delete_model_image(
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
@@ -451,8 +449,9 @@ async def install_model(
|
||||
- model/name:fp16:path/to/model.safetensors
|
||||
- model/name::path/to/model.safetensors
|
||||
|
||||
`config` is an optional dict containing model configuration values that will override
|
||||
the ones that are probed automatically.
|
||||
`config` is a ModelRecordChanges object. Fields in this object will override
|
||||
the ones that are probed automatically. Pass an empty object to accept
|
||||
all the defaults.
|
||||
|
||||
`access_token` is an optional access token for use with Urls that require
|
||||
authentication.
|
||||
@@ -737,7 +736,7 @@ async def convert_model(
|
||||
# write the converted file to the convert path
|
||||
raw_model = converted_model.model
|
||||
assert hasattr(raw_model, "save_pretrained")
|
||||
raw_model.save_pretrained(convert_path)
|
||||
raw_model.save_pretrained(convert_path) # type: ignore
|
||||
assert convert_path.exists()
|
||||
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
@@ -750,12 +749,12 @@ async def convert_model(
|
||||
try:
|
||||
new_key = installer.install_path(
|
||||
convert_path,
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
config=ModelRecordChanges(
|
||||
name=original_name,
|
||||
description=model_config.description,
|
||||
hash=model_config.hash,
|
||||
source=model_config.source,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByOriginResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
PruneResult,
|
||||
@@ -105,6 +106,19 @@ async def cancel_by_batch_ids(
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_origin",
|
||||
operation_id="cancel_by_origin",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_origin(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
origin: str = Query(description="The origin to cancel all queue items for"),
|
||||
) -> CancelByOriginResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/clear",
|
||||
operation_id="clear",
|
||||
|
||||
274
invokeai/app/api/routers/style_presets.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
InvalidPresetImportDataError,
|
||||
PresetData,
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordWithImage,
|
||||
StylePresetWithoutId,
|
||||
UnsupportedFileTypeError,
|
||||
parse_presets_from_file,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetFormData(BaseModel):
|
||||
name: str = Field(description="Preset name")
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
type: PresetType = Field(description="Preset type")
|
||||
|
||||
|
||||
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="get_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def get_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to get"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Gets a style preset"""
|
||||
try:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
|
||||
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
|
||||
except StylePresetNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Style preset not found")
|
||||
|
||||
|
||||
@style_presets_router.patch(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="update_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def update_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
style_preset_id: str = Path(description="The id of the style preset to update"),
|
||||
data: str = Form(description="The data of the style preset to update"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Updates a style preset"""
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
else:
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
|
||||
|
||||
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
||||
style_preset_id=style_preset_id, changes=changes
|
||||
)
|
||||
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.delete(
|
||||
"/i/{style_preset_id}",
|
||||
operation_id="delete_style_preset",
|
||||
)
|
||||
async def delete_style_preset(
|
||||
style_preset_id: str = Path(description="The style preset to delete"),
|
||||
) -> None:
|
||||
"""Deletes a style preset"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
||||
except StylePresetImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/",
|
||||
operation_id="create_style_preset",
|
||||
responses={
|
||||
200: {"model": StylePresetRecordWithImage},
|
||||
},
|
||||
)
|
||||
async def create_style_preset(
|
||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
||||
data: str = Form(description="The data of the style preset to create"),
|
||||
) -> StylePresetRecordWithImage:
|
||||
"""Creates a style preset"""
|
||||
|
||||
try:
|
||||
parsed_data = json.loads(data)
|
||||
validated_data = StylePresetFormData(**parsed_data)
|
||||
|
||||
name = validated_data.name
|
||||
type = validated_data.type
|
||||
positive_prompt = validated_data.positive_prompt
|
||||
negative_prompt = validated_data.negative_prompt
|
||||
|
||||
except pydantic.ValidationError:
|
||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||
|
||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
|
||||
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
|
||||
|
||||
if image is not None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
|
||||
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/",
|
||||
operation_id="list_style_presets",
|
||||
responses={
|
||||
200: {"model": list[StylePresetRecordWithImage]},
|
||||
},
|
||||
)
|
||||
async def list_style_presets() -> list[StylePresetRecordWithImage]:
|
||||
"""Gets a page of style presets"""
|
||||
style_presets_with_image: list[StylePresetRecordWithImage] = []
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
|
||||
for preset in style_presets:
|
||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
|
||||
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
|
||||
style_presets_with_image.append(style_preset_with_image)
|
||||
|
||||
return style_presets_with_image
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/i/{style_preset_id}/image",
|
||||
operation_id="get_style_preset_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The style preset image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The style preset image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_style_preset_image(
|
||||
style_preset_id: str = Path(description="The id of the style preset image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=style_preset_id + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@style_presets_router.get(
|
||||
"/export",
|
||||
operation_id="export_style_presets",
|
||||
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
|
||||
status_code=200,
|
||||
)
|
||||
async def export_style_presets():
|
||||
# Create an in-memory stream to store the CSV data
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write the header
|
||||
writer.writerow(["name", "prompt", "negative_prompt"])
|
||||
|
||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
|
||||
|
||||
for preset in style_presets:
|
||||
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
|
||||
|
||||
csv_data = output.getvalue()
|
||||
output.close()
|
||||
|
||||
return Response(
|
||||
content=csv_data,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
|
||||
)
|
||||
|
||||
|
||||
@style_presets_router.post(
|
||||
"/import",
|
||||
operation_id="import_style_presets",
|
||||
)
|
||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
||||
try:
|
||||
style_presets = await parse_presets_from_file(file)
|
||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
||||
except InvalidPresetImportDataError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except UnsupportedFileTypeError as e:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
@@ -30,6 +30,7 @@ from invokeai.app.api.routers import (
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
@@ -55,11 +56,13 @@ mimetypes.add_type("text/css", ".css")
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
@@ -106,6 +109,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
@@ -184,8 +188,6 @@ def invoke_api() -> None:
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
# Start our own event loop for eventing usage
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
|
||||
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(
|
||||
text_encoder,
|
||||
loras=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
model_state_dict=state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||
|
||||
@@ -21,6 +21,8 @@ from controlnet_aux import (
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
@@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return color_map
|
||||
|
||||
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "LiheYoung/depth-anything-large-hf",
|
||||
"base": "LiheYoung/depth-anything-base-hf",
|
||||
"small": "LiheYoung/depth-anything-small-hf",
|
||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.2",
|
||||
version="1.1.3",
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small", description="The size of the depth model to use"
|
||||
default="small_v2", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
def load_depth_anything(model_path: Path):
|
||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||
) as depth_anything_detector:
|
||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||
depth_map = depth_anything_detector.generate_depth(image)
|
||||
|
||||
# Resizing to user target specified size
|
||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
||||
depth_map = depth_map.resize((self.resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
@@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
blur_tensor[blur_tensor < 0] = 0.0
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@@ -36,9 +37,10 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
@@ -53,6 +55,19 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -314,9 +329,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||
unet: UNet2DConditionModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
cfg_scale: float | list[float],
|
||||
steps: int,
|
||||
cfg_rescale_multiplier: float,
|
||||
@@ -330,10 +346,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
uncond_list = [uncond_list]
|
||||
|
||||
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
cond_list, context, unet.device, unet.dtype
|
||||
cond_list, context, device, dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||
uncond_list, context, unet.device, unet.dtype
|
||||
uncond_list, context, device, dtype
|
||||
)
|
||||
|
||||
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
@@ -341,14 +357,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
dtype=dtype,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if isinstance(cfg_scale, list):
|
||||
@@ -455,6 +471,65 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return controlnet_data
|
||||
|
||||
@staticmethod
|
||||
def parse_controlnet_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
control_input: ControlField | list[ControlField] | None,
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
# Normalize control_input to a list.
|
||||
control_list: list[ControlField]
|
||||
if isinstance(control_input, ControlField):
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list):
|
||||
control_list = control_input
|
||||
elif control_input is None:
|
||||
control_list = []
|
||||
else:
|
||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||
|
||||
for control_info in control_list:
|
||||
model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
ext_manager.add_extension(
|
||||
ControlNetExt(
|
||||
model=model,
|
||||
image=context.images.get_pil(control_info.image.image_name),
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_t2i_adapter_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
if t2i_adapters is None:
|
||||
return
|
||||
|
||||
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
||||
if isinstance(t2i_adapters, T2IAdapterField):
|
||||
t2i_adapters = [t2i_adapters]
|
||||
|
||||
for t2i_adapter_field in t2i_adapters:
|
||||
ext_manager.add_extension(
|
||||
T2IAdapterExt(
|
||||
node_context=context,
|
||||
model_id=t2i_adapter_field.t2i_adapter_model,
|
||||
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
||||
weight=t2i_adapter_field.weight,
|
||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -664,7 +739,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
return mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@staticmethod
|
||||
def prepare_noise_and_latents(
|
||||
@@ -707,12 +782,157 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return seed, noise, latents
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
if os.environ.get("USE_MODULAR_DENOISE", False):
|
||||
return self._new_invoke(context)
|
||||
else:
|
||||
return self._old_invoke(context)
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
# TODO: old backend, remove
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
seed=seed,
|
||||
device=device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
### preview
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
ext_manager.add_extension(PreviewExt(step_callback))
|
||||
|
||||
### cfg rescale
|
||||
if self.cfg_rescale_multiplier > 0:
|
||||
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
|
||||
|
||||
### freeu
|
||||
if self.unet.freeu_config:
|
||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||
|
||||
### lora
|
||||
if self.unet.loras:
|
||||
for lora_field in self.unet.loras:
|
||||
ext_manager.add_extension(
|
||||
LoRAExt(
|
||||
node_context=context,
|
||||
model_id=lora_field.lora,
|
||||
weight=lora_field.weight,
|
||||
)
|
||||
)
|
||||
### seamless
|
||||
if self.unet.seamless_axes:
|
||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||
|
||||
### inpaint
|
||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||
if unet_config.variant == ModelVariantType.Inpaint:
|
||||
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||
elif mask is not None:
|
||||
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||
|
||||
# Initialize context for modular denoise
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
attention_processor_cls=CustomAttnProcessor2_0,
|
||||
),
|
||||
unet=None,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
# later should be smth like:
|
||||
# for extension_field in self.extensions:
|
||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||
# ext_manager.add_extension(ext)
|
||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||
# ext: controlnet
|
||||
ext_manager.patch_extensions(denoise_ctx),
|
||||
# ext: freeu, seamless, ip adapter, lora
|
||||
ext_manager.patch_unet(unet, cached_weights),
|
||||
):
|
||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||
denoise_ctx.unet = unet
|
||||
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.detach().to("cpu")
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||
|
||||
@torch.no_grad()
|
||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
||||
# We invert the mask here for compatibility with the old backend implementation.
|
||||
if mask is not None:
|
||||
mask = 1 - mask
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
@@ -755,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
loras=_lora_loader(),
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
@@ -788,7 +1008,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -134,6 +135,7 @@ class FieldDescriptions:
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
spandrel_image_to_image_model = "Image-to-Image model"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
@@ -240,6 +242,31 @@ class ConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BoundingBoxField(BaseModel):
|
||||
"""A bounding box primitive value."""
|
||||
|
||||
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
|
||||
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
|
||||
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
|
||||
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
|
||||
|
||||
score: Optional[float] = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
|
||||
"when the bounding box was produced by a detector and has an associated confidence score.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_coords(self):
|
||||
if self.x_min > self.x_max:
|
||||
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
|
||||
if self.y_min > self.y_max:
|
||||
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
|
||||
return self
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
|
||||
100
invokeai/app/invocations/grounding_dino.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
|
||||
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
|
||||
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
|
||||
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
|
||||
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"grounding_dino",
|
||||
title="Grounding DINO (Text Prompt Object Detection)",
|
||||
tags=["prompt", "object detection"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GroundingDinoInvocation(BaseInvocation):
|
||||
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
|
||||
|
||||
# Reference:
|
||||
# - https://arxiv.org/pdf/2303.05499
|
||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
|
||||
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
|
||||
prompt: str = InputField(description="The prompt describing the object to segment.")
|
||||
image: ImageField = InputField(description="The image to segment.")
|
||||
detection_threshold: float = InputField(
|
||||
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
default=0.3,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
|
||||
# The model expects a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
detections = self._detect(
|
||||
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
|
||||
)
|
||||
|
||||
# Convert detections to BoundingBoxCollectionOutput.
|
||||
bounding_boxes: list[BoundingBoxField] = []
|
||||
for detection in detections:
|
||||
bounding_boxes.append(
|
||||
BoundingBoxField(
|
||||
x_min=detection.box.xmin,
|
||||
x_max=detection.box.xmax,
|
||||
y_min=detection.box.ymin,
|
||||
y_max=detection.box.ymax,
|
||||
score=detection.score,
|
||||
)
|
||||
)
|
||||
return BoundingBoxCollectionOutput(collection=bounding_boxes)
|
||||
|
||||
@staticmethod
|
||||
def _load_grounding_dino(model_path: Path):
|
||||
grounding_dino_pipeline = pipeline(
|
||||
model=str(model_path),
|
||||
task="zero-shot-object-detection",
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
||||
|
||||
def _detect(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
labels: list[str],
|
||||
threshold: float = 0.3,
|
||||
) -> list[DetectionResult]:
|
||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
||||
# actually makes a difference.
|
||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
||||
|
||||
with context.models.load_remote_model(
|
||||
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
|
||||
) as detector:
|
||||
assert isinstance(detector, GroundingDinoPipeline)
|
||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
||||
@@ -6,13 +6,19 @@ import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
ColorField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
OutputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
@@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation_output("canvas_v2_mask_and_crop_output")
|
||||
class CanvasV2MaskAndCropOutput(ImageOutput):
|
||||
offset_x: int = OutputField(description="The x offset of the image, after cropping")
|
||||
offset_y: int = OutputField(description="The y offset of the image, after cropping")
|
||||
|
||||
|
||||
@invocation(
|
||||
"canvas_v2_mask_and_crop",
|
||||
title="Canvas V2 Mask and Crop",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Handles Canvas V2 image output masking and cropping"""
|
||||
|
||||
source_image: ImageField | None = InputField(
|
||||
default=None,
|
||||
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
|
||||
)
|
||||
generated_image: ImageField = InputField(description="The image to apply the mask to")
|
||||
mask: ImageField = InputField(description="The mask to apply")
|
||||
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
|
||||
|
||||
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
|
||||
mask_array = numpy.array(mask)
|
||||
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
|
||||
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
|
||||
dilated_mask = Image.fromarray(dilated_mask_array)
|
||||
if self.mask_blur > 0:
|
||||
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
return ImageOps.invert(mask.convert("L"))
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
|
||||
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
|
||||
|
||||
if self.source_image:
|
||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||
source_image = context.images.get_pil(self.source_image.image_name)
|
||||
source_image.paste(generated_image, (0, 0), mask)
|
||||
image_dto = context.images.save(image=source_image)
|
||||
else:
|
||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||
generated_image.putalpha(mask)
|
||||
image_dto = context.images.save(image=generated_image)
|
||||
|
||||
# bbox = image.getbbox()
|
||||
# image = image.crop(bbox)
|
||||
|
||||
return CanvasV2MaskAndCropOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
offset_x=0,
|
||||
offset_y=0,
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
height=mask.shape[1],
|
||||
width=mask.shape[2],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"tensor_mask_to_image",
|
||||
title="Tensor Mask to Image",
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Convert a mask tensor to an image."""
|
||||
|
||||
mask: TensorField = InputField(description="The mask tensor to convert.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
# Ensure that the mask is binary.
|
||||
if mask.dtype != torch.bool:
|
||||
mask = mask > 0.5
|
||||
mask_np = (mask.float() * 255).byte().cpu().numpy()
|
||||
|
||||
mask_pil = Image.fromarray(mask_np, mode="L")
|
||||
image_dto = context.images.save(image=mask_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoundingBoxField,
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
@@ -469,3 +470,42 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region BoundingBox
|
||||
|
||||
|
||||
@invocation_output("bounding_box_output")
|
||||
class BoundingBoxOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single bounding box"""
|
||||
|
||||
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
|
||||
|
||||
|
||||
@invocation_output("bounding_box_collection_output")
|
||||
class BoundingBoxCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of bounding boxes"""
|
||||
|
||||
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
|
||||
|
||||
|
||||
@invocation(
|
||||
"bounding_box",
|
||||
title="Bounding Box",
|
||||
tags=["primitives", "segmentation", "collection", "bounding box"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BoundingBoxInvocation(BaseInvocation):
|
||||
"""Create a bounding box manually by supplying box coordinates"""
|
||||
|
||||
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
|
||||
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
|
||||
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
|
||||
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
|
||||
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
|
||||
return BoundingBoxOutput(bounding_box=bounding_box)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
161
invokeai/app/invocations/segment_anything.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
|
||||
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
|
||||
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
|
||||
"segment-anything-base": "facebook/sam-vit-base",
|
||||
"segment-anything-large": "facebook/sam-vit-large",
|
||||
"segment-anything-huge": "facebook/sam-vit-huge",
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"segment_anything",
|
||||
title="Segment Anything",
|
||||
tags=["prompt", "segmentation"],
|
||||
category="segmentation",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SegmentAnythingInvocation(BaseInvocation):
|
||||
"""Runs a Segment Anything Model."""
|
||||
|
||||
# Reference:
|
||||
# - https://arxiv.org/pdf/2304.02643
|
||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
|
||||
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
|
||||
image: ImageField = InputField(description="The image to segment.")
|
||||
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
|
||||
apply_polygon_refinement: bool = InputField(
|
||||
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
|
||||
default=True,
|
||||
)
|
||||
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
|
||||
description="The filtering to apply to the detected masks before merging them into a final output.",
|
||||
default="all",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
# The models expect a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
if len(self.bounding_boxes) == 0:
|
||||
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
|
||||
else:
|
||||
masks = self._segment(context=context, image=image_pil)
|
||||
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
|
||||
|
||||
# masks contains bool values, so we merge them via max-reduce.
|
||||
combined_mask, _ = torch.stack(masks).max(dim=0)
|
||||
|
||||
mask_tensor_name = context.tensors.save(combined_mask)
|
||||
height, width = combined_mask.shape
|
||||
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
|
||||
|
||||
@staticmethod
|
||||
def _load_sam_model(model_path: Path):
|
||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
||||
model_path,
|
||||
local_files_only=True,
|
||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
||||
# model, and figure out how to make it work in the pipeline.
|
||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
assert isinstance(sam_model, SamModel)
|
||||
|
||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
||||
assert isinstance(sam_processor, SamProcessor)
|
||||
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
|
||||
|
||||
def _segment(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
image: Image.Image,
|
||||
) -> list[torch.Tensor]:
|
||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
||||
# Convert the bounding boxes to the SAM input format.
|
||||
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
|
||||
|
||||
with (
|
||||
context.models.load_remote_model(
|
||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
|
||||
) as sam_pipeline,
|
||||
):
|
||||
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
||||
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
|
||||
|
||||
masks = self._process_masks(masks)
|
||||
if self.apply_polygon_refinement:
|
||||
masks = self._apply_polygon_refinement(masks)
|
||||
|
||||
return masks
|
||||
|
||||
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""Convert the tensor output from the Segment Anything model from a tensor of shape
|
||||
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
|
||||
"""
|
||||
assert masks.dtype == torch.bool
|
||||
# [num_masks, channels, height, width] -> [num_masks, height, width]
|
||||
masks, _ = masks.max(dim=1)
|
||||
# Split the first dimension into a list of masks.
|
||||
return list(masks.cpu().unbind(dim=0))
|
||||
|
||||
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""Apply polygon refinement to the masks.
|
||||
|
||||
Convert each mask to a polygon, then back to a mask. This has the following effect:
|
||||
- Smooth the edges of the mask slightly.
|
||||
- Ensure that each mask consists of a single closed polygon
|
||||
- Removes small mask pieces.
|
||||
- Removes holes from the mask.
|
||||
"""
|
||||
# Convert tensor masks to np masks.
|
||||
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
|
||||
|
||||
# Apply polygon refinement.
|
||||
for idx, mask in enumerate(np_masks):
|
||||
shape = mask.shape
|
||||
assert len(shape) == 2 # Assert length to satisfy type checker.
|
||||
polygon = mask_to_polygon(mask)
|
||||
mask = polygon_to_mask(polygon, shape)
|
||||
np_masks[idx] = mask
|
||||
|
||||
# Convert np masks back to tensor masks.
|
||||
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
|
||||
|
||||
return masks
|
||||
|
||||
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
|
||||
"""Filter the detected masks based on the specified mask filter."""
|
||||
assert len(masks) == len(bounding_boxes)
|
||||
|
||||
if self.mask_filter == "all":
|
||||
return masks
|
||||
elif self.mask_filter == "largest":
|
||||
# Find the largest mask.
|
||||
return [max(masks, key=lambda x: float(x.sum()))]
|
||||
elif self.mask_filter == "highest_box_score":
|
||||
# Find the index of the bounding box with the highest score.
|
||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
||||
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
|
||||
# reasonable fallback since the expected score range is [0.0, 1.0].
|
||||
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
|
||||
return [masks[max_score_idx]]
|
||||
else:
|
||||
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
|
||||
253
invokeai/app/invocations/spandrel_image_to_image.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
|
||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
image_to_image_model: ModelIdentifierField = InputField(
|
||||
title="Image-to-Image Model",
|
||||
description=FieldDescriptions.spandrel_image_to_image_model,
|
||||
ui_type=UIType.SpandrelImageToImageModel,
|
||||
)
|
||||
tile_size: int = InputField(
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
||||
return Tile(
|
||||
coords=TBLR(
|
||||
top=tile.coords.top * scale,
|
||||
bottom=tile.coords.bottom * scale,
|
||||
left=tile.coords.left * scale,
|
||||
right=tile.coords.right * scale,
|
||||
),
|
||||
overlap=TBLR(
|
||||
top=tile.overlap.top * scale,
|
||||
bottom=tile.overlap.bottom * scale,
|
||||
left=tile.overlap.left * scale,
|
||||
right=tile.overlap.right * scale,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def upscale_image(
|
||||
cls,
|
||||
image: Image.Image,
|
||||
tile_size: int,
|
||||
spandrel_model: SpandrelImageToImageModel,
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> Image.Image:
|
||||
# Compute the image tiles.
|
||||
if tile_size > 0:
|
||||
min_overlap = 20
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=image.height,
|
||||
image_width=image.width,
|
||||
tile_height=tile_size,
|
||||
tile_width=tile_size,
|
||||
min_overlap=min_overlap,
|
||||
)
|
||||
else:
|
||||
# No tiling. Generate a single tile that covers the entire image.
|
||||
min_overlap = 0
|
||||
tiles = [
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
|
||||
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||
)
|
||||
]
|
||||
|
||||
# Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
|
||||
# over tiles left-to-right, top-to-bottom.
|
||||
tiles = sorted(tiles, key=lambda x: x.coords.left)
|
||||
tiles = sorted(tiles, key=lambda x: x.coords.top)
|
||||
|
||||
# Prepare input image for inference.
|
||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||
|
||||
# Scale the tiles for re-assembling the final image.
|
||||
scale = spandrel_model.scale
|
||||
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
|
||||
|
||||
# Prepare the output tensor.
|
||||
_, channels, height, width = image_tensor.shape
|
||||
output_tensor = torch.zeros(
|
||||
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on each tile.
|
||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||
# Exit early if the invocation has been canceled.
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# Extract the current tile from the input tensor.
|
||||
input_tile = image_tensor[
|
||||
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
||||
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on the tile.
|
||||
output_tile = spandrel_model.run(input_tile)
|
||||
|
||||
# Convert the output tile into the output tensor's format.
|
||||
# (N, C, H, W) -> (C, H, W)
|
||||
output_tile = output_tile.squeeze(0)
|
||||
# (C, H, W) -> (H, W, C)
|
||||
output_tile = output_tile.permute(1, 2, 0)
|
||||
output_tile = output_tile.clamp(0, 1)
|
||||
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
|
||||
|
||||
# Merge the output tile into the output tensor.
|
||||
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
|
||||
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
|
||||
# it seems unnecessary, but we may find a need in the future.
|
||||
top_overlap = scaled_tile.overlap.top // 2
|
||||
left_overlap = scaled_tile.overlap.left // 2
|
||||
output_tensor[
|
||||
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
|
||||
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
|
||||
:,
|
||||
] = output_tile[top_overlap:, left_overlap:, :]
|
||||
|
||||
# Convert the output tensor to a PIL image.
|
||||
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||
pil_image = Image.fromarray(np_image)
|
||||
|
||||
return pil_image
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Upscale the image
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"spandrel_image_to_image_autoscale",
|
||||
title="Image-to-Image (Autoscale)",
|
||||
tags=["upscale"],
|
||||
category="upscale",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
|
||||
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
|
||||
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
|
||||
target_width = int(image.width * self.scale)
|
||||
target_height = int(image.height * self.scale)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# First pass of upscaling. Note: `pil_image` will be mutated.
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
|
||||
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
|
||||
# to be considered an upscale model.
|
||||
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
|
||||
|
||||
if is_upscale_model:
|
||||
# This is an upscale model, so we should keep upscaling until we reach the target size.
|
||||
iterations = 1
|
||||
while pil_image.width < target_width or pil_image.height < target_height:
|
||||
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
iterations += 1
|
||||
|
||||
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
|
||||
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
|
||||
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
|
||||
# we should never reach this limit.
|
||||
if iterations >= 5:
|
||||
context.logger.warning(
|
||||
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
|
||||
)
|
||||
break
|
||||
else:
|
||||
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
|
||||
# to be the same as the processed image size.
|
||||
|
||||
# The output size is now the size of the processed image.
|
||||
target_width = pil_image.width
|
||||
target_height = pil_image.height
|
||||
|
||||
# Warn the user if they requested a scale greater than 1.
|
||||
if self.scale > 1:
|
||||
context.logger.warning(
|
||||
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
|
||||
)
|
||||
|
||||
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
|
||||
# in the final resize
|
||||
if self.fit_to_multiple_of_8:
|
||||
target_width = int(target_width // 8 * 8)
|
||||
target_height = int(target_height // 8 * 8)
|
||||
|
||||
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
|
||||
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
|
||||
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -175,6 +175,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# Calculate the tile locations to cover the latent-space image.
|
||||
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
|
||||
# - How much overlap 'context' to provide for each denoising step.
|
||||
# - How much overlap to use during merging/blending.
|
||||
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=latent_height,
|
||||
image_width=latent_width,
|
||||
@@ -218,7 +222,8 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
device=unet.device,
|
||||
dtype=unet.dtype,
|
||||
latent_height=latent_tile_height,
|
||||
latent_width=latent_tile_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
|
||||
@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
custom_nodes_dir: Path to directory for custom nodes.
|
||||
style_presets_dir: Path to directory for style presets.
|
||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
@@ -153,6 +154,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
|
||||
|
||||
# LOGGING
|
||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||
@@ -300,6 +302,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def style_presets_path(self) -> Path:
|
||||
"""Path to the style presets directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.style_presets_dir)
|
||||
|
||||
@property
|
||||
def convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
|
||||
@@ -88,6 +88,7 @@ class QueueItemEventBase(QueueEventBase):
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
@@ -95,8 +96,6 @@ class InvocationEventBase(QueueItemEventBase):
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
@@ -114,6 +113,7 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -147,6 +147,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -184,6 +185,7 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -216,6 +218,7 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -253,6 +256,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
@@ -279,12 +283,14 @@ class BatchEnqueuedEvent(QueueEventBase):
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
origin=enqueue_result.batch.origin,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
|
||||
@@ -1,46 +1,44 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._queue = asyncio.Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._loop = loop
|
||||
|
||||
# We need to store a reference to the task so it doesn't get GC'd
|
||||
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
|
||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.remove)
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
event = await self._queue.get()
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
# Leave the payloads as live pydantic models
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.image_files.image_files_common import (
|
||||
@@ -20,18 +19,12 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: Path
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[Path, PILImageType]
|
||||
__max_cache_size: int
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__cache: dict[Path, PILImageType] = {}
|
||||
self.__cache_ids = Queue[Path]()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||
# Validate required output folders at launch
|
||||
self.__validate_storage_folders()
|
||||
@@ -103,7 +96,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
if image_path.exists():
|
||||
send2trash(image_path)
|
||||
image_path.unlink()
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
@@ -111,7 +104,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||
|
||||
if thumbnail_path.exists():
|
||||
send2trash(thumbnail_path)
|
||||
thumbnail_path.unlink()
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
@@ -61,6 +63,8 @@ class InvocationServices:
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -85,3 +89,5 @@ class InvocationServices:
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
|
||||
@@ -2,7 +2,6 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
|
||||
@@ -70,7 +69,7 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
send2trash(path)
|
||||
path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
@@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
|
||||
This keeps the model in its current location.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
@@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
:param source: String source
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
:param config: Optional ModelRecordChanges object. Any fields in this object
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record as described in `import_model()`.
|
||||
:param access_token: Optional access token for remote sources.
|
||||
@@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@ import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
from typing import Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
@@ -133,8 +134,9 @@ class ModelInstallJob(BaseModel):
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
config_in: ModelRecordChanges = Field(
|
||||
default_factory=ModelRecordChanges,
|
||||
description="Configuration information (e.g. 'description') to apply to model.",
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
|
||||
@@ -163,26 +163,27 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
config = config or ModelRecordChanges()
|
||||
if not config.source:
|
||||
config.source = model_path.resolve().as_posix()
|
||||
config.source_type = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
info: AnyModelConfig = ModelProbe.probe(
|
||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
||||
) # type: ignore
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
if preferred_name := config.name:
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
dest_path = (
|
||||
@@ -204,7 +205,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
@@ -216,7 +217,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
|
||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||
if similar_jobs:
|
||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||
@@ -318,16 +319,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
model_path = self._app_config.models_path / model_path
|
||||
model_path = model_path.resolve()
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
config["name"] = model_name
|
||||
config["description"] = stanza.get("description")
|
||||
config = ModelRecordChanges(
|
||||
name=model_name,
|
||||
description=stanza.get("description"),
|
||||
)
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
config.config_path = str(legacy_config_path)
|
||||
try:
|
||||
id = self.register_path(model_path=model_path, config=config)
|
||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||
@@ -500,11 +502,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
job.config_in.source = str(job.source)
|
||||
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
job.config_in.source_api_response = job.source_metadata.api_response
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
@@ -639,11 +641,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return new_path
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
@@ -674,11 +676,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
precision = TorchDevice.choose_torch_dtype()
|
||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_local_model(
|
||||
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
|
||||
) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
@@ -686,7 +690,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
if source.access_token is None:
|
||||
@@ -702,7 +706,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
@@ -717,7 +721,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: HFModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges],
|
||||
) -> ModelInstallJob:
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
@@ -730,7 +734,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
install_job = ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
source_metadata=metadata,
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
|
||||
@@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@@ -66,10 +67,16 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
source: Optional[str] = Field(description="original source of the model", default=None)
|
||||
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
|
||||
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
||||
@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByOriginResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
||||
"""Cancels all queue items with the given batch origin"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
|
||||
@@ -77,6 +77,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of this batch.")
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
@@ -195,6 +196,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
origin: str | None = Field(default=None, description="The origin of this queue item. ")
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
@@ -294,6 +296,7 @@ class SessionQueueStatus(BaseModel):
|
||||
class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
origin: str | None = Field(..., description="The origin of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
@@ -328,6 +331,12 @@ class CancelByBatchIDsResult(BaseModel):
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByOriginResult(BaseModel):
|
||||
"""Result of canceling by list of batch ids"""
|
||||
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
@@ -433,6 +442,7 @@ class SessionQueueValueToInsert(NamedTuple):
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
@@ -453,6 +463,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
||||
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
||||
priority, # priority
|
||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||
batch.origin, # origin
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByOriginResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND origin == ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = (queue_id, origin)
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.origin == origin:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByOriginResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
@@ -541,7 +578,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id
|
||||
queue_id,
|
||||
origin
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -621,7 +659,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
SELECT status, count(*), origin
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -633,6 +671,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -641,6 +680,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
origin=origin,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
|
||||
@@ -16,6 +16,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -49,6 +51,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_12(app_config=config))
|
||||
migrator.register_migration(build_migration_13())
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.register_migration(build_migration_15())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration14Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_style_presets(cursor)
|
||||
|
||||
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create the table used to store style presets."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS style_presets (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
preset_data TEXT NOT NULL,
|
||||
type TEXT NOT NULL DEFAULT "user",
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS style_presets
|
||||
AFTER UPDATE
|
||||
ON style_presets FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_14() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 13 to 14..
|
||||
|
||||
This migration does the following:
|
||||
- Create the table used to store style presets.
|
||||
"""
|
||||
migration_14 = Migration(
|
||||
from_version=13,
|
||||
to_version=14,
|
||||
callback=Migration14Callback(),
|
||||
)
|
||||
|
||||
return migration_14
|
||||
@@ -0,0 +1,31 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration15Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._add_origin_col(cursor)
|
||||
|
||||
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
- Adds `origin` column to the session queue table.
|
||||
"""
|
||||
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
|
||||
|
||||
|
||||
def build_migration_15() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 14 to 15.
|
||||
|
||||
This migration does the following:
|
||||
- Adds `origin` column to the session queue table.
|
||||
"""
|
||||
migration_15 = Migration(
|
||||
from_version=14,
|
||||
to_version=15,
|
||||
callback=Migration15Callback(),
|
||||
)
|
||||
|
||||
return migration_15
|
||||
|
After Width: | Height: | Size: 98 KiB |
|
After Width: | Height: | Size: 138 KiB |
|
After Width: | Height: | Size: 122 KiB |
|
After Width: | Height: | Size: 123 KiB |
|
After Width: | Height: | Size: 160 KiB |
|
After Width: | Height: | Size: 146 KiB |
|
After Width: | Height: | Size: 119 KiB |
|
After Width: | Height: | Size: 117 KiB |
|
After Width: | Height: | Size: 110 KiB |
|
After Width: | Height: | Size: 46 KiB |
|
After Width: | Height: | Size: 79 KiB |
|
After Width: | Height: | Size: 156 KiB |
|
After Width: | Height: | Size: 141 KiB |
|
After Width: | Height: | Size: 96 KiB |
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 88 KiB |
|
After Width: | Height: | Size: 107 KiB |
|
After Width: | Height: | Size: 132 KiB |
@@ -0,0 +1,33 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class StylePresetImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
"""Retrieves a style preset image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
"""Gets the internal path to a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
"""Gets the URL to fetch a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
"""Saves a style preset image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset image."""
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
class StylePresetImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class StylePresetImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message: str = "Style preset image file not deleted"):
|
||||
super().__init__(message)
|
||||
@@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_common import (
|
||||
StylePresetImageFileDeleteException,
|
||||
StylePresetImageFileNotFoundException,
|
||||
StylePresetImageFileSaveException,
|
||||
)
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
|
||||
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, style_preset_images_folder: Path):
|
||||
self._style_preset_images_folder = style_preset_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, style_preset_id: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
|
||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileSaveException from e
|
||||
|
||||
def get_path(self, style_preset_id: str) -> Path:
|
||||
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
|
||||
if style_preset.type is PresetType.Default:
|
||||
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
|
||||
path = default_images_dir / (style_preset.name + ".png")
|
||||
else:
|
||||
path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, style_preset_id: str) -> str | None:
|
||||
path = self.get_path(style_preset_id)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
|
||||
|
||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
||||
url += f"?{uuid_string()}"
|
||||
|
||||
return url
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
path = self.get_path(style_preset_id)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise StylePresetImageFileNotFoundException
|
||||
|
||||
path.unlink()
|
||||
|
||||
except StylePresetImageFileNotFoundException as e:
|
||||
raise StylePresetImageFileNotFoundException from e
|
||||
except Exception as e:
|
||||
raise StylePresetImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)
|
||||
@@ -0,0 +1,146 @@
|
||||
[
|
||||
{
|
||||
"name": "Photography (General)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Studio Lighting)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Landscape)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Portrait)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Photography (Black and White)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
|
||||
"negative_prompt": "painting, digital art. sketch, colour+"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Architectural Visualization",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
|
||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Fantasy)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Sci-Fi)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Character)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
|
||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Concept Art (Painterly)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
|
||||
"negative_prompt": "photo. smooth. border. frame"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Environment Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
|
||||
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Interior Design (Visualization)",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
|
||||
"negative_prompt": "photo, distorted. sketch."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Product Rendering",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
|
||||
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Sketch",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
|
||||
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Line Art",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
|
||||
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Anime",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
|
||||
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Illustration",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
|
||||
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Vehicles",
|
||||
"type": "default",
|
||||
"preset_data": {
|
||||
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
|
||||
"negative_prompt": "sketch. digital art. greyscale. painting"
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
|
||||
|
||||
class StylePresetRecordsStorageBase(ABC):
|
||||
"""Base class for style preset storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Get style preset by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
"""Creates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
"""Creates many style presets."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
"""Updates a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
"""Deletes a style preset."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -0,0 +1,139 @@
|
||||
import codecs
|
||||
import csv
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from fastapi import UploadFile
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
class StylePresetNotFoundError(Exception):
|
||||
"""Raised when a style preset is not found"""
|
||||
|
||||
|
||||
class PresetData(BaseModel, extra="forbid"):
|
||||
positive_prompt: str = Field(description="Positive prompt")
|
||||
negative_prompt: str = Field(description="Negative prompt")
|
||||
|
||||
|
||||
PresetDataValidator = TypeAdapter(PresetData)
|
||||
|
||||
|
||||
class PresetType(str, Enum, metaclass=MetaEnum):
|
||||
User = "user"
|
||||
Default = "default"
|
||||
Project = "project"
|
||||
|
||||
|
||||
class StylePresetChanges(BaseModel, extra="forbid"):
|
||||
name: Optional[str] = Field(default=None, description="The style preset's new name.")
|
||||
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
|
||||
type: Optional[PresetType] = Field(description="The updated type of the style preset")
|
||||
|
||||
|
||||
class StylePresetWithoutId(BaseModel):
|
||||
name: str = Field(description="The name of the style preset.")
|
||||
preset_data: PresetData = Field(description="The preset data")
|
||||
type: PresetType = Field(description="The type of style preset")
|
||||
|
||||
|
||||
class StylePresetRecordDTO(StylePresetWithoutId):
|
||||
id: str = Field(description="The style preset ID.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
||||
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
|
||||
return StylePresetRecordDTOValidator.validate_python(data)
|
||||
|
||||
|
||||
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
||||
|
||||
|
||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
||||
image: Optional[str] = Field(description="The path for image")
|
||||
|
||||
|
||||
class StylePresetImportRow(BaseModel):
|
||||
name: str = Field(min_length=1, description="The name of the preset.")
|
||||
positive_prompt: str = Field(
|
||||
default="",
|
||||
description="The positive prompt for the preset.",
|
||||
validation_alias=AliasChoices("positive_prompt", "prompt"),
|
||||
)
|
||||
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
||||
|
||||
|
||||
StylePresetImportList = list[StylePresetImportRow]
|
||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(ValueError):
|
||||
"""Raised when an unsupported file type is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidPresetImportDataError(ValueError):
|
||||
"""Raised when invalid preset import data is encountered"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
|
||||
"""Parses style presets from a file. The file must be a CSV or JSON file.
|
||||
|
||||
If CSV, the file must have the following columns:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
If JSON, the file must be a list of objects with the following keys:
|
||||
- name
|
||||
- prompt (or positive_prompt)
|
||||
- negative_prompt
|
||||
|
||||
Args:
|
||||
file (UploadFile): The file to parse.
|
||||
|
||||
Returns:
|
||||
list[StylePresetWithoutId]: The parsed style presets.
|
||||
|
||||
Raises:
|
||||
UnsupportedFileTypeError: If the file type is not supported.
|
||||
InvalidPresetImportDataError: If the data in the file is invalid.
|
||||
"""
|
||||
if file.content_type not in ["text/csv", "application/json"]:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if file.content_type == "text/csv":
|
||||
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
|
||||
data = list(csv_reader)
|
||||
else: # file.content_type == "application/json":
|
||||
json_data = await file.read()
|
||||
data = json.loads(json_data)
|
||||
|
||||
try:
|
||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
||||
|
||||
style_presets: list[StylePresetWithoutId] = []
|
||||
|
||||
for imported in imported_presets:
|
||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
||||
style_presets.append(style_preset)
|
||||
except pydantic.ValidationError as e:
|
||||
if file.content_type == "text/csv":
|
||||
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
else: # file.content_type == "application/json":
|
||||
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
||||
raise InvalidPresetImportDataError(msg) from e
|
||||
finally:
|
||||
file.file.close()
|
||||
|
||||
return style_presets
|
||||
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._sync_default_style_presets()
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
name,
|
||||
preset_data,
|
||||
type
|
||||
)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
style_preset_id,
|
||||
style_preset.name,
|
||||
style_preset.preset_data.model_dump_json(),
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET name = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.name, style_preset_id),
|
||||
)
|
||||
|
||||
# Change the preset data for a style preset
|
||||
if changes.preset_data is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET preset_data = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
self._cursor.execute(main_query, (type,))
|
||||
else:
|
||||
self._cursor.execute(main_query)
|
||||
|
||||
rows = self._cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Next, parse and create the default style presets
|
||||
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
for preset in presets:
|
||||
style_preset = StylePresetWithoutId.model_validate(preset)
|
||||
self.create(style_preset)
|
||||
@@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
"""Gets the URL for a style preset image"""
|
||||
pass
|
||||
|
||||
@@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
||||
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"
|
||||
|
||||
@@ -81,7 +81,7 @@ def get_openapi_func(
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": invocation_output_map_properties,
|
||||
"properties": dict(sorted(invocation_output_map_properties.items())),
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
width=518,
|
||||
height=518,
|
||||
resize_target=False,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=14,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
logger.warn("DepthAnything model was not loaded. Returning original image")
|
||||
return image
|
||||
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
np_image = np_image[:, :, ::-1] / 255.0
|
||||
|
||||
image_height, image_width = np_image.shape[:2]
|
||||
np_image = transform({"image": np_image})["image"]
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
depth = self.model(tensor_image)
|
||||
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
||||
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||||
|
||||
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||
depth_map = Image.fromarray(depth_map)
|
||||
|
||||
new_height = int(image_height * (resolution / image_width))
|
||||
depth_map = depth_map.resize((resolution, new_height))
|
||||
|
||||
return depth_map
|
||||
@@ -0,0 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import DepthEstimationPipeline
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class DepthAnythingPipeline(RawModel):
|
||||
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
||||
for Invoke's Model Management System"""
|
||||
|
||||
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
|
||||
self._pipeline = pipeline
|
||||
|
||||
def generate_depth(self, image: Image.Image) -> Image.Image:
|
||||
depth_map = self._pipeline(image)["depth"]
|
||||
assert isinstance(depth_map, Image.Image)
|
||||
return depth_map
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
@@ -1,145 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape
|
||||
|
||||
if expand:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape * 2
|
||||
out_shape3 = out_shape * 4
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape * 8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
if len(in_shape) >= 4:
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
"""Residual convolution module."""
|
||||
|
||||
def __init__(self, features, activation, bn):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.bn = bn
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||
|
||||
if self.bn:
|
||||
self.bn1 = nn.BatchNorm2d(features)
|
||||
self.bn2 = nn.BatchNorm2d(features)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
if self.bn:
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
if self.bn:
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.groups > 1:
|
||||
out = self.conv_merge(out)
|
||||
|
||||
return self.skip_add.add(out, x)
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
"""Feature fusion block."""
|
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock, self).__init__()
|
||||
|
||||
self.deconv = deconv
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.groups = 1
|
||||
|
||||
self.expand = expand
|
||||
out_features = features
|
||||
if self.expand:
|
||||
out_features = features // 2
|
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
||||
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
self.size = size
|
||||
|
||||
def forward(self, *xs, size=None):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
res = self.resConfUnit1(xs[1])
|
||||
output = self.skip_add.add(output, res)
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
if (size is None) and (self.size is None):
|
||||
modifier = {"scale_factor": 2}
|
||||
elif size is None:
|
||||
modifier = {"size": self.size}
|
||||
else:
|
||||
modifier = {"size": size}
|
||||
|
||||
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
||||
|
||||
output = self.out_conv(output)
|
||||
|
||||
return output
|
||||
@@ -1,183 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
|
||||
|
||||
torchhub_path = Path(__file__).parent.parent / "torchhub"
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size=None):
|
||||
return FeatureFusionBlock(
|
||||
features,
|
||||
nn.ReLU(False),
|
||||
deconv=False,
|
||||
bn=use_bn,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
|
||||
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
|
||||
super(DPTHead, self).__init__()
|
||||
|
||||
self.nclass = nclass
|
||||
self.use_clstoken = use_clstoken
|
||||
|
||||
self.projects = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
for out_channel in out_channels
|
||||
]
|
||||
)
|
||||
|
||||
self.resize_layers = nn.ModuleList(
|
||||
[
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
||||
),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if use_clstoken:
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
|
||||
|
||||
self.scratch = _make_scratch(
|
||||
out_channels,
|
||||
features,
|
||||
groups=1,
|
||||
expand=False,
|
||||
)
|
||||
|
||||
self.scratch.stem_transpose = None
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
if nclass > 1:
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.scratch.output_conv1 = nn.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, out_features, patch_h, patch_w):
|
||||
out = []
|
||||
for i, x in enumerate(out_features):
|
||||
if self.use_clstoken:
|
||||
x, cls_token = x[0], x[1]
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
else:
|
||||
x = x[0]
|
||||
|
||||
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
||||
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
|
||||
out.append(x)
|
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = out
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv1(path_1)
|
||||
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
||||
out = self.scratch.output_conv2(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DPT_DINOv2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
features,
|
||||
out_channels,
|
||||
encoder="vitl",
|
||||
use_bn=False,
|
||||
use_clstoken=False,
|
||||
):
|
||||
super(DPT_DINOv2, self).__init__()
|
||||
|
||||
assert encoder in ["vits", "vitb", "vitl"]
|
||||
|
||||
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
||||
# if use_local:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# torchhub_path / "facebookresearch_dinov2_main",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# source="local",
|
||||
# pretrained=False,
|
||||
# )
|
||||
# else:
|
||||
# self.pretrained = torch.hub.load(
|
||||
# "facebookresearch/dinov2",
|
||||
# "dinov2_{:}14".format(encoder),
|
||||
# )
|
||||
|
||||
self.pretrained = torch.hub.load(
|
||||
"facebookresearch/dinov2",
|
||||
"dinov2_{:}14".format(encoder),
|
||||
)
|
||||
|
||||
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
||||
|
||||
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
|
||||
|
||||
patch_h, patch_w = h // 14, w // 14
|
||||
|
||||
depth = self.depth_head(features, patch_h, patch_w)
|
||||
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
|
||||
depth = F.relu(depth)
|
||||
|
||||
return depth.squeeze(1)
|
||||
@@ -1,227 +0,0 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||
|
||||
Args:
|
||||
sample (dict): sample
|
||||
size (tuple): image size
|
||||
|
||||
Returns:
|
||||
tuple: new size
|
||||
"""
|
||||
shape = list(sample["disparity"].shape)
|
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||
return sample
|
||||
|
||||
scale = [0, 0]
|
||||
scale[0] = size[0] / shape[0]
|
||||
scale[1] = size[1] / shape[1]
|
||||
|
||||
scale = max(scale)
|
||||
|
||||
shape[0] = math.ceil(scale * shape[0])
|
||||
shape[1] = math.ceil(scale * shape[1])
|
||||
|
||||
# resize
|
||||
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
|
||||
|
||||
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
tuple(shape[::-1]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize sample to given size (width, height)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
resize_target=True,
|
||||
keep_aspect_ratio=False,
|
||||
ensure_multiple_of=1,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_AREA,
|
||||
):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
width (int): desired output width
|
||||
height (int): desired output height
|
||||
resize_target (bool, optional):
|
||||
True: Resize the full sample (image, mask, target).
|
||||
False: Resize image only.
|
||||
Defaults to True.
|
||||
keep_aspect_ratio (bool, optional):
|
||||
True: Keep the aspect ratio of the input sample.
|
||||
Output sample might not have the given width and height, and
|
||||
resize behaviour depends on the parameter 'resize_method'.
|
||||
Defaults to False.
|
||||
ensure_multiple_of (int, optional):
|
||||
Output width and height is constrained to be multiple of this parameter.
|
||||
Defaults to 1.
|
||||
resize_method (str, optional):
|
||||
"lower_bound": Output will be at least as large as the given size.
|
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
|
||||
than given size.)
|
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||
Defaults to "lower_bound".
|
||||
"""
|
||||
self.__width = width
|
||||
self.__height = height
|
||||
|
||||
self.__resize_target = resize_target
|
||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||
self.__multiple_of = ensure_multiple_of
|
||||
self.__resize_method = resize_method
|
||||
self.__image_interpolation_method = image_interpolation_method
|
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if max_val is not None and y > max_val:
|
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if y < min_val:
|
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
return y
|
||||
|
||||
def get_size(self, width, height):
|
||||
# determine new height and width
|
||||
scale_height = self.__height / height
|
||||
scale_width = self.__width / width
|
||||
|
||||
if self.__keep_aspect_ratio:
|
||||
if self.__resize_method == "lower_bound":
|
||||
# scale such that output size is lower bound
|
||||
if scale_width > scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "upper_bound":
|
||||
# scale such that output size is upper bound
|
||||
if scale_width < scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "minimal":
|
||||
# scale as least as possbile
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
if self.__resize_method == "lower_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
||||
elif self.__resize_method == "upper_bound":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
||||
elif self.__resize_method == "minimal":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
return (new_width, new_height)
|
||||
|
||||
def __call__(self, sample):
|
||||
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
||||
|
||||
# resize sample
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"],
|
||||
(width, height),
|
||||
interpolation=self.__image_interpolation_method,
|
||||
)
|
||||
|
||||
if self.__resize_target:
|
||||
if "disparity" in sample:
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"],
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
if "depth" in sample:
|
||||
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
# sample["semseg_mask"] = cv2.resize(
|
||||
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||
# )
|
||||
sample["semseg_mask"] = F.interpolate(
|
||||
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
|
||||
).numpy()[0, 0]
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
# sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
# print(sample['image'].shape, sample['depth'].shape)
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""Normlize image by given mean and std."""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.__mean = mean
|
||||
self.__std = std
|
||||
|
||||
def __call__(self, sample):
|
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
|
||||
"""Prepare sample for usage as network input."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, sample):
|
||||
image = np.transpose(sample["image"], (2, 0, 1))
|
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].astype(np.float32)
|
||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||
|
||||
if "depth" in sample:
|
||||
depth = sample["depth"].astype(np.float32)
|
||||
sample["depth"] = np.ascontiguousarray(depth)
|
||||
|
||||
if "semseg_mask" in sample:
|
||||
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
||||
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
||||
|
||||
return sample
|
||||
@@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
|
||||
"""Bounding box helper class."""
|
||||
|
||||
xmin: int
|
||||
ymin: int
|
||||
xmax: int
|
||||
ymax: int
|
||||
|
||||
|
||||
class DetectionResult(BaseModel):
|
||||
"""Detection result from Grounding DINO."""
|
||||
|
||||
score: float
|
||||
label: str
|
||||
box: BoundingBox
|
||||
model_config = ConfigDict(
|
||||
# Allow arbitrary types for mask, since it will be a numpy array.
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
||||
|
||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class GroundingDinoPipeline(RawModel):
|
||||
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
|
||||
management system.
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
|
||||
self._pipeline = pipeline
|
||||
|
||||
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
|
||||
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
|
||||
assert results is not None
|
||||
results = [DetectionResult.model_validate(result) for result in results]
|
||||
return results
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
||||
# CUDA.
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._pipeline.model.to(device=device, dtype=dtype)
|
||||
self._pipeline.device = self._pipeline.model.device
|
||||
|
||||
def calc_size(self) -> int:
|
||||
# HACK(ryand): Fix the circular import issue.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._pipeline.model)
|
||||
@@ -0,0 +1,50 @@
|
||||
# This file contains utilities for Grounded-SAM mask refinement based on:
|
||||
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
||||
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
|
||||
"""Convert a binary mask to a polygon.
|
||||
|
||||
Returns:
|
||||
list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
|
||||
"""
|
||||
# Find contours in the binary mask.
|
||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Find the contour with the largest area.
|
||||
largest_contour = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Extract the vertices of the contour.
|
||||
polygon = largest_contour.reshape(-1, 2).tolist()
|
||||
|
||||
return polygon
|
||||
|
||||
|
||||
def polygon_to_mask(
|
||||
polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
|
||||
) -> npt.NDArray[np.uint8]:
|
||||
"""Convert a polygon to a segmentation mask.
|
||||
|
||||
Args:
|
||||
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
|
||||
image_shape (tuple): Shape of the image (height, width) for the mask.
|
||||
fill_value (int): Value to fill the polygon with.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Segmentation mask with the polygon filled (with value 255).
|
||||
"""
|
||||
# Create an empty mask.
|
||||
mask = np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
# Convert polygon to an array of points.
|
||||
pts = np.array(polygon, dtype=np.int32)
|
||||
|
||||
# Fill the polygon with white color (255).
|
||||
cv2.fillPoly(mask, [pts], color=(fill_value,))
|
||||
|
||||
return mask
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class SegmentAnythingPipeline(RawModel):
|
||||
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
||||
|
||||
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
||||
self._sam_model = sam_model
|
||||
self._sam_processor = sam_processor
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
|
||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
||||
device = None
|
||||
self._sam_model.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
# HACK(ryand): Fix the circular import issue.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._sam_model)
|
||||
|
||||
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
|
||||
"""Run the SAM model.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to segment.
|
||||
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
|
||||
[xmin, ymin, xmax, ymax].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
|
||||
"""
|
||||
# Add batch dimension of 1 to the bounding boxes.
|
||||
boxes = [bounding_boxes]
|
||||
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
|
||||
outputs = self._sam_model(**inputs)
|
||||
masks = self._sam_processor.post_process_masks(
|
||||
masks=outputs.pred_masks,
|
||||
original_sizes=inputs.original_sizes,
|
||||
reshaped_input_sizes=inputs.reshaped_input_sizes,
|
||||
)
|
||||
|
||||
# There should be only one batch.
|
||||
assert len(masks) == 1
|
||||
return masks[0]
|
||||
@@ -124,16 +124,14 @@ class IPAdapter(RawModel):
|
||||
self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
def to(
|
||||
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
|
||||
):
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
|
||||
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
|
||||
self._image_proj_model.to(device=self.device, dtype=self.dtype)
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
# HACK(ryand): Fix this issue with circular imports.
|
||||
|
||||
@@ -3,15 +3,15 @@
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
@@ -47,9 +47,19 @@ class LoRALayerBase:
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
@@ -57,14 +67,20 @@ class LoRALayerBase:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
logger.warning(
|
||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
@@ -82,14 +98,19 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
@@ -106,19 +127,14 @@ class LoRALayer(LoRALayerBase):
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
@@ -136,20 +152,23 @@ class LoHALayer(LoRALayerBase):
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
@@ -167,23 +186,18 @@ class LoHALayer(LoRALayerBase):
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
@@ -202,37 +216,45 @@ class LoKRLayer(LoRALayerBase):
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
@@ -264,12 +286,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
@@ -277,23 +294,25 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -303,15 +322,12 @@ class FullLayer(LoRALayerBase):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
|
||||
if len(values.keys()) > 1:
|
||||
_keys = list(values.keys())
|
||||
_keys.remove("diff")
|
||||
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -319,15 +335,10 @@ class FullLayer(LoRALayerBase):
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
@@ -345,8 +356,9 @@ class IA3Layer(LoRALayerBase):
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
@@ -359,19 +371,46 @@ class IA3Layer(LoRALayerBase):
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
):
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
class NormLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["w_norm"]
|
||||
self.bias = values.get("b_norm", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"w_norm", "b_norm"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
@@ -390,15 +429,10 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
@@ -494,16 +528,19 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
if "lora_up.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
elif "hada_w1_a" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
@@ -511,9 +548,13 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "weight" in values and "on_input" in values:
|
||||
elif "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
# norms
|
||||
elif "w_norm" in values:
|
||||
layer = NormLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
@@ -521,7 +562,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
@@ -353,7 +354,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@@ -364,13 +365,24 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
|
||||
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
@@ -407,6 +419,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
|
||||
@@ -167,7 +167,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
size = calc_model_size_by_data(self.logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
running_on_cpu = self.execution_device == torch.device("cpu")
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
@@ -289,11 +290,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(
|
||||
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
|
||||
)
|
||||
new_dict[k] = v.to(target_device, copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
|
||||
)
|
||||
class SpandrelImageToImageModelLoader(ModelLoader):
|
||||
"""Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("Unexpected submodel requested for Spandrel model.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
model = SpandrelImageToImageModel.load_from_file(model_path)
|
||||
|
||||
torch_dtype = self._torch_dtype
|
||||
if not model.supports_dtype(torch_dtype):
|
||||
self._logger.warning(
|
||||
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
|
||||
"model. Falling back to 'float32'."
|
||||
)
|
||||
torch_dtype = torch.float32
|
||||
model.to(dtype=torch_dtype)
|
||||
|
||||
return model
|
||||
@@ -98,6 +98,9 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
},
|
||||
}
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
try:
|
||||
|
||||
@@ -11,10 +11,14 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
@@ -33,7 +37,18 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
elif isinstance(model, CLIPTokenizer):
|
||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||
return 0
|
||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
|
||||
elif isinstance(
|
||||
model,
|
||||
(
|
||||
TextualInversionModelRaw,
|
||||
IPAdapter,
|
||||
LoRAModelRaw,
|
||||
SpandrelImageToImageModel,
|
||||
GroundingDinoPipeline,
|
||||
SegmentAnythingPipeline,
|
||||
DepthAnythingPipeline,
|
||||
),
|
||||
):
|
||||
return model.calc_size()
|
||||
else:
|
||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
@@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
CkptType = Dict[str | int, Any]
|
||||
@@ -220,24 +222,46 @@ class ModelProbe(object):
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
return ModelType.LoRA
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
|
||||
elif key.startswith(("controlnet", "control_model", "input_blocks")):
|
||||
return ModelType.ControlNet
|
||||
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
|
||||
elif key.startswith(("image_proj.", "ip_adapter.")):
|
||||
return ModelType.IPAdapter
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
# explored to avoid this:
|
||||
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
|
||||
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
|
||||
# supported on meta tensors.
|
||||
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
|
||||
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||
# maintain it, and the risk of false positive detections is higher.
|
||||
SpandrelImageToImageModel.load_from_file(model_path)
|
||||
return ModelType.SpandrelImageToImage
|
||||
except spandrel.UnsupportedModelError:
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
if "No such file or directory" in str(e):
|
||||
# This error is expected if the model_path does not exist (which is the case in some unit tests).
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
|
||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
@@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
@@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
@@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
|
||||
@@ -187,157 +187,171 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
# endregion
|
||||
# region ControlNet
|
||||
StarterModel(
|
||||
name="QRCode Monster",
|
||||
name="QRCode Monster v2 (SD1.5)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
||||
description="Controlnet model that generates scannable creative QR codes",
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="QRCode Monster (SDXL)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="monster-labs/control_v1p_sdxl_qrcode_monster",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_canny",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="inpaint",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="mlsd",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="normal_bae",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="seg",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_seg",
|
||||
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_lineart",
|
||||
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart_anime",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_openpose",
|
||||
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_scribble",
|
||||
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_softedge",
|
||||
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="shuffle",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="ip2p",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
source="xinsir/controlNet-canny-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
source="diffusers/controlNet-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge-dexined-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
||||
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
||||
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-16bit-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-openpose-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
source="xinsir/controlNet-openpose-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-scribble-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
source="xinsir/controlNet-scribble-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-tile-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
@@ -399,6 +413,43 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
# endregion
|
||||
# region SpandrelImageToImage
|
||||
StarterModel(
|
||||
name="RealESRGAN_x4plus_anime_6B",
|
||||
base=BaseModelType.Any,
|
||||
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
StarterModel(
|
||||
name="RealESRGAN_x4plus",
|
||||
base=BaseModelType.Any,
|
||||
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
description="A Real-ESRGAN 4x upscaling model (general-purpose).",
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
StarterModel(
|
||||
name="ESRGAN_SRx4_DF2KOST_official",
|
||||
base=BaseModelType.Any,
|
||||
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description="The official ESRGAN 4x upscaling model.",
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
StarterModel(
|
||||
name="RealESRGAN_x2plus",
|
||||
base=BaseModelType.Any,
|
||||
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
description="A Real-ESRGAN 2x upscaling model (general-purpose).",
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
StarterModel(
|
||||
name="SwinIR - realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN",
|
||||
base=BaseModelType.Any,
|
||||
source="https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN-with-dict-keys-params-and-params_ema.pth",
|
||||
description="A SwinIR 4x upscaling model.",
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
# endregion
|
||||
]
|
||||
|
||||
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
"""
|
||||
loras = [
|
||||
@@ -32,8 +33,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
"""
|
||||
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
|
||||
"""A context manager that patches `unet` with the provided attention processor.
|
||||
|
||||
Args:
|
||||
unet (UNet2DConditionModel): The UNet model to patch.
|
||||
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
|
||||
"""
|
||||
unet_orig_processors = unet.attn_processors
|
||||
|
||||
# create separate instance for each attention, to be able modify each attention separately
|
||||
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
|
||||
try:
|
||||
unet.set_attn_processor(unet_new_processors)
|
||||
yield None
|
||||
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
@@ -66,13 +86,13 @@ class ModelPatcher:
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(
|
||||
unet,
|
||||
loras=loras,
|
||||
prefix="lora_unet_",
|
||||
model_state_dict=model_state_dict,
|
||||
cached_weights=cached_weights,
|
||||
):
|
||||
yield
|
||||
|
||||
@@ -82,9 +102,9 @@ class ModelPatcher:
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@@ -94,7 +114,7 @@ class ModelPatcher:
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
@@ -102,71 +122,26 @@ class ModelPatcher:
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
original_weights = {}
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
for lora_model, lora_weight in loras:
|
||||
LoRAExt.patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
lora=lora_model,
|
||||
lora_weight=lora_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
||||
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
||||
else:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(
|
||||
device=TorchDevice.CPU_DEVICE,
|
||||
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
|
||||
)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
assert hasattr(layer_weight, "reshape")
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
|
||||
yield # wait for context manager exit
|
||||
yield
|
||||
|
||||
finally:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(
|
||||
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
||||
)
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
|
||||
@@ -190,12 +190,7 @@ class IAIOnnxRuntimeModel(RawModel):
|
||||
return self.session.run(None, inputs)
|
||||
|
||||
# compatability with RawModel ABC
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
pass
|
||||
|
||||
# compatability with diffusers load code
|
||||
|
||||
@@ -1,15 +1,3 @@
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,13 +5,18 @@ import torch
|
||||
|
||||
|
||||
class RawModel(ABC):
|
||||
"""Abstract base class for 'Raw' model wrappers."""
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
pass
|
||||
|
||||
139
invokeai/backend/spandrel_image_to_image_model.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from spandrel import ImageModelDescriptor, ModelLoader
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class SpandrelImageToImageModel(RawModel):
|
||||
"""A wrapper for a Spandrel Image-to-Image model.
|
||||
|
||||
The main reason for having a wrapper class is to integrate with the type handling of RawModel.
|
||||
"""
|
||||
|
||||
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
||||
self._spandrel_model = spandrel_model
|
||||
|
||||
@staticmethod
|
||||
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
|
||||
"""Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
|
||||
|
||||
Args:
|
||||
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||
"""
|
||||
image_np = np.array(image)
|
||||
# (H, W, C) -> (C, H, W)
|
||||
image_np = np.transpose(image_np, (2, 0, 1))
|
||||
image_np = image_np / 255
|
||||
image_tensor = torch.from_numpy(image_np).float()
|
||||
# (C, H, W) -> (N, C, H, W)
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
return image_tensor
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
|
||||
"""Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||
|
||||
Returns:
|
||||
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
||||
"""
|
||||
# (N, C, H, W) -> (C, H, W)
|
||||
tensor = tensor.squeeze(0)
|
||||
# (C, H, W) -> (H, W, C)
|
||||
tensor = tensor.permute(1, 2, 0)
|
||||
tensor = tensor.clamp(0, 1)
|
||||
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
|
||||
image = Image.fromarray(tensor)
|
||||
return image
|
||||
|
||||
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the image-to-image model.
|
||||
|
||||
Args:
|
||||
image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||
"""
|
||||
return self._spandrel_model(image_tensor)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str | Path):
|
||||
model = ModelLoader().load_from_file(file_path)
|
||||
if not isinstance(model, ImageModelDescriptor):
|
||||
raise ValueError(
|
||||
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||
"('ImageModelDescriptor')."
|
||||
)
|
||||
|
||||
return cls(spandrel_model=model)
|
||||
|
||||
@classmethod
|
||||
def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||
model = ModelLoader().load_from_state_dict(state_dict)
|
||||
if not isinstance(model, ImageModelDescriptor):
|
||||
raise ValueError(
|
||||
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||
"('ImageModelDescriptor')."
|
||||
)
|
||||
|
||||
return cls(spandrel_model=model)
|
||||
|
||||
def supports_dtype(self, dtype: torch.dtype) -> bool:
|
||||
"""Check if the model supports the given dtype."""
|
||||
if dtype == torch.float16:
|
||||
return self._spandrel_model.supports_half
|
||||
elif dtype == torch.bfloat16:
|
||||
return self._spandrel_model.supports_bfloat16
|
||||
elif dtype == torch.float32:
|
||||
# All models support float32.
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unexpected dtype '{dtype}'.")
|
||||
|
||||
def get_model_type_name(self) -> str:
|
||||
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
|
||||
consistent over time.
|
||||
"""
|
||||
return str(type(self._spandrel_model.model))
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
"""Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
|
||||
Note: The non_blocking parameter is currently ignored."""
|
||||
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
|
||||
# model directly if we want to apply this optimization.
|
||||
self._spandrel_model.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""The device of the underlying model."""
|
||||
return self._spandrel_model.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""The dtype of the underlying model."""
|
||||
return self._spandrel_model.dtype
|
||||
|
||||
@property
|
||||
def scale(self) -> int:
|
||||
"""The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
|
||||
return self._spandrel_model.scale
|
||||
|
||||
def calc_size(self) -> int:
|
||||
"""Get size of the model in memory in bytes."""
|
||||
# HACK(ryand): Fix this issue with circular imports.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._spandrel_model.model)
|
||||
@@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"PipelineIntermediateState",
|
||||
"StableDiffusionGeneratorPipeline",
|
||||
"InvokeAIDiffuserComponent",
|
||||
"set_seamless",
|
||||
]
|
||||
|
||||
131
invokeai/backend/stable_diffusion/denoise_context.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetKwargs:
|
||||
sample: torch.Tensor
|
||||
timestep: Union[torch.Tensor, float, int]
|
||||
encoder_hidden_states: torch.Tensor
|
||||
|
||||
class_labels: Optional[torch.Tensor] = None
|
||||
timestep_cond: Optional[torch.Tensor] = None
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None
|
||||
# return_dict: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseInputs:
|
||||
"""Initial variables passed to denoise. Supposed to be unchanged."""
|
||||
|
||||
# The latent-space image to denoise.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
# - If we are inpainting, this is the initial latent image before noise has been added.
|
||||
# - If we are generating a new image, this should be initialized to zeros.
|
||||
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
||||
orig_latents: torch.Tensor
|
||||
|
||||
# kwargs forwarded to the scheduler.step() method.
|
||||
scheduler_step_kwargs: dict[str, Any]
|
||||
|
||||
# Text conditionging data.
|
||||
conditioning_data: TextConditioningData
|
||||
|
||||
# Noise used for two purposes:
|
||||
# 1. Used by the scheduler to noise the initial `latents` before denoising.
|
||||
# 2. Used to noise the `masked_latents` when inpainting.
|
||||
# `noise` should be None if the `latents` tensor has already been noised.
|
||||
# Shape: [1 or batch, channels, latent_height, latent_width]
|
||||
noise: Optional[torch.Tensor]
|
||||
|
||||
# The seed used to generate the noise for the denoising process.
|
||||
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
||||
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
||||
seed: int
|
||||
|
||||
# The timestep schedule for the denoising process.
|
||||
timesteps: torch.Tensor
|
||||
|
||||
# The first timestep in the schedule. This is used to determine the initial noise level, so
|
||||
# should be populated if you want noise applied *even* if timesteps is empty.
|
||||
init_timestep: torch.Tensor
|
||||
|
||||
# Class of attention processor that is used.
|
||||
attention_processor_cls: Type[Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenoiseContext:
|
||||
"""Context with all variables in denoise"""
|
||||
|
||||
# Initial variables passed to denoise. Supposed to be unchanged.
|
||||
inputs: DenoiseInputs
|
||||
|
||||
# Scheduler which used to apply noise predictions.
|
||||
scheduler: SchedulerMixin
|
||||
|
||||
# UNet model.
|
||||
unet: Optional[UNet2DConditionModel] = None
|
||||
|
||||
# Current state of latent-space image in denoising process.
|
||||
# None until `PRE_DENOISE_LOOP` callback.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latents: Optional[torch.Tensor] = None
|
||||
|
||||
# Current denoising step index.
|
||||
# None until `PRE_STEP` callback.
|
||||
step_index: Optional[int] = None
|
||||
|
||||
# Current denoising step timestep.
|
||||
# None until `PRE_STEP` callback.
|
||||
timestep: Optional[torch.Tensor] = None
|
||||
|
||||
# Arguments which will be passed to UNet model.
|
||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||
unet_kwargs: Optional[UNetKwargs] = None
|
||||
|
||||
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
||||
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
|
||||
step_output: Optional[SchedulerOutput] = None
|
||||
|
||||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latent_model_input: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Defines on which conditionings current unet call will be runned.
|
||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||
conditioning_mode: Optional[ConditioningMode] = None
|
||||
|
||||
# [TMP] Noise predictions from negative conditioning.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
negative_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Noise predictions from positive conditioning.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
positive_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# Combined noise prediction from passed conditionings.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# Dictionary for extensions to pass extra info about denoise process to other extensions.
|
||||
extra: dict = field(default_factory=dict)
|
||||
@@ -23,21 +23,12 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
step: int
|
||||
order: int
|
||||
total_steps: int
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddsMaskGuidance:
|
||||
mask: torch.Tensor
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -95,6 +102,12 @@ class TextConditioningRegions:
|
||||
assert self.masks.shape[1] == len(self.ranges)
|
||||
|
||||
|
||||
class ConditioningMode(Enum):
|
||||
Both = "both"
|
||||
Negative = "negative"
|
||||
Positive = "positive"
|
||||
|
||||
|
||||
class TextConditioningData:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -103,7 +116,7 @@ class TextConditioningData:
|
||||
uncond_regions: Optional[TextConditioningRegions],
|
||||
cond_regions: Optional[TextConditioningRegions],
|
||||
guidance_scale: Union[float, List[float]],
|
||||
guidance_rescale_multiplier: float = 0,
|
||||
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
|
||||
):
|
||||
self.uncond_text = uncond_text
|
||||
self.cond_text = cond_text
|
||||
@@ -114,6 +127,7 @@ class TextConditioningData:
|
||||
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
self.guidance_scale = guidance_scale
|
||||
# TODO: old backend, remove
|
||||
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
|
||||
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||
@@ -121,3 +135,114 @@ class TextConditioningData:
|
||||
def is_sdxl(self):
|
||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
|
||||
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
|
||||
"""Fills unet arguments with data from provided conditionings.
|
||||
|
||||
Args:
|
||||
unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
|
||||
conditioning_mode (ConditioningMode): Describes which conditionings should be used.
|
||||
"""
|
||||
_, _, h, w = unet_kwargs.sample.shape
|
||||
device = unet_kwargs.sample.device
|
||||
dtype = unet_kwargs.sample.dtype
|
||||
|
||||
# TODO: combine regions with conditionings
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
conditionings = [self.uncond_text, self.cond_text]
|
||||
c_regions = [self.uncond_regions, self.cond_regions]
|
||||
elif conditioning_mode == ConditioningMode.Positive:
|
||||
conditionings = [self.cond_text]
|
||||
c_regions = [self.cond_regions]
|
||||
elif conditioning_mode == ConditioningMode.Negative:
|
||||
conditionings = [self.uncond_text]
|
||||
c_regions = [self.uncond_regions]
|
||||
else:
|
||||
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
|
||||
|
||||
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
[c.embeds for c in conditionings]
|
||||
)
|
||||
|
||||
unet_kwargs.encoder_hidden_states = encoder_hidden_states
|
||||
unet_kwargs.encoder_attention_mask = encoder_attention_mask
|
||||
|
||||
if self.is_sdxl():
|
||||
added_cond_kwargs = dict( # noqa: C408
|
||||
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
|
||||
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
|
||||
)
|
||||
|
||||
unet_kwargs.added_cond_kwargs = added_cond_kwargs
|
||||
|
||||
if any(r is not None for r in c_regions):
|
||||
tmp_regions = []
|
||||
for c, r in zip(conditionings, c_regions, strict=True):
|
||||
if r is None:
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=dtype),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
)
|
||||
tmp_regions.append(r)
|
||||
|
||||
if unet_kwargs.cross_attention_kwargs is None:
|
||||
unet_kwargs.cross_attention_kwargs = {}
|
||||
|
||||
unet_kwargs.cross_attention_kwargs.update(
|
||||
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
|
||||
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
|
||||
|
||||
@classmethod
|
||||
def _pad_conditioning(
|
||||
cls,
|
||||
cond: torch.Tensor,
|
||||
target_len: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): Conditioning tensor which to pads by zeros.
|
||||
target_len (int): To which length(tokens count) pad tensor.
|
||||
"""
|
||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||
|
||||
if cond.shape[1] < target_len:
|
||||
conditioning_attention_mask = cls._pad_zeros(
|
||||
conditioning_attention_mask,
|
||||
pad_shape=(cond.shape[0], target_len - cond.shape[1]),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
cond = cls._pad_zeros(
|
||||
cond,
|
||||
pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return cond, conditioning_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _concat_conditionings_for_batch(
|
||||
cls,
|
||||
conditionings: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Concatenate provided conditioning tensors to one batched tensor.
|
||||
If tensors have different sizes then pad them by zeros and creates
|
||||
encoder_attention_mask to exclude padding from attention.
|
||||
|
||||
Args:
|
||||
conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
|
||||
"""
|
||||
encoder_attention_mask = None
|
||||
max_len = max([c.shape[1] for c in conditionings])
|
||||
if any(c.shape[1] != max_len for c in conditionings):
|
||||
encoder_attention_masks = [None] * len(conditionings)
|
||||
for i in range(len(conditionings)):
|
||||
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
|
||||
encoder_attention_mask = torch.cat(encoder_attention_masks)
|
||||
|
||||
return torch.cat(conditionings), encoder_attention_mask
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningRegions,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningRegions,
|
||||
)
|
||||
|
||||
|
||||
class RegionalPromptData:
|
||||
|
||||