mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 07:28:06 -05:00
Compare commits
787 Commits
feat/laten
...
fix/nodes/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fba55a3acc | ||
|
|
1758ff35c3 | ||
|
|
c60a71f0e9 | ||
|
|
ad2247c314 | ||
|
|
9de54b2266 | ||
|
|
5aaaaf64a1 | ||
|
|
9140e2c0f2 | ||
|
|
83e2b7578b | ||
|
|
df1907e849 | ||
|
|
22c337b1aa | ||
|
|
339e7ce213 | ||
|
|
2a178f5a25 | ||
|
|
1bc170727b | ||
|
|
3722cdf5d6 | ||
|
|
42a59aa147 | ||
|
|
b937b7da01 | ||
|
|
21245a0fb2 | ||
|
|
da566b59e8 | ||
|
|
e4dc9c5a04 | ||
|
|
aceadacad4 | ||
|
|
d3dec59cc3 | ||
|
|
6c98700740 | ||
|
|
c4c3c96062 | ||
|
|
6256be480c | ||
|
|
7033071934 | ||
|
|
e48528bbef | ||
|
|
6bdf68dd4c | ||
|
|
0c3616229e | ||
|
|
604cc1adcd | ||
|
|
4847212d5b | ||
|
|
727293d722 | ||
|
|
d2f3500e1b | ||
|
|
ef83a2fffe | ||
|
|
f8d7477c7a | ||
|
|
e374211313 | ||
|
|
01d17601b8 | ||
|
|
bf0d5f4cfc | ||
|
|
663f4935f5 | ||
|
|
9838dda1b7 | ||
|
|
2d889e133d | ||
|
|
6779f1a5ad | ||
|
|
19a6e5dad8 | ||
|
|
285195bf72 | ||
|
|
10008859a4 | ||
|
|
3c04340f3f | ||
|
|
79f0c4d3c4 | ||
|
|
37d4e05838 | ||
|
|
a00ad6ac03 | ||
|
|
2ffead000c | ||
|
|
922319cb84 | ||
|
|
6ee0e197bb | ||
|
|
d3e6f0130c | ||
|
|
421c23d3ea | ||
|
|
4545f3209f | ||
|
|
e2ee8102c2 | ||
|
|
083a0fc4cf | ||
|
|
26b75b85f7 | ||
|
|
f560a462a0 | ||
|
|
d501986610 | ||
|
|
67a75f6895 | ||
|
|
3c032c0767 | ||
|
|
abd6561140 | ||
|
|
bd533426fc | ||
|
|
2489d5459f | ||
|
|
ac477cf5d6 | ||
|
|
be3bdae847 | ||
|
|
3e0ee838cf | ||
|
|
8d3bec57d5 | ||
|
|
cfda128e06 | ||
|
|
661a94b3de | ||
|
|
9ef64016c7 | ||
|
|
21f0d0b0c1 | ||
|
|
8bce234542 | ||
|
|
daadf6ebfd | ||
|
|
fe10a9f747 | ||
|
|
7a2d3f628a | ||
|
|
4defb92105 | ||
|
|
f9f3c91a83 | ||
|
|
95b9c8e505 | ||
|
|
49a02c157b | ||
|
|
d604d986f9 | ||
|
|
70cc037a9c | ||
|
|
e4893e4031 | ||
|
|
4a0a718b96 | ||
|
|
ca8f1a7828 | ||
|
|
2e41af2109 | ||
|
|
bd29e5e655 | ||
|
|
dcfee2e1e4 | ||
|
|
8aac683319 | ||
|
|
d306a84447 | ||
|
|
5865ecd530 | ||
|
|
e1f9685b02 | ||
|
|
498bf0d0ba | ||
|
|
163ef2c941 | ||
|
|
48193b7fa7 | ||
|
|
dd1b3c9f35 | ||
|
|
4b32322a58 | ||
|
|
e06c43adc8 | ||
|
|
c009f46b00 | ||
|
|
748016bdab | ||
|
|
72e9ced889 | ||
|
|
3833304f57 | ||
|
|
4bfaae6617 | ||
|
|
499a174832 | ||
|
|
6ca5ad9075 | ||
|
|
a121e6b3a0 | ||
|
|
207602f425 | ||
|
|
a1671519d5 | ||
|
|
257e972599 | ||
|
|
8639794c12 | ||
|
|
d339c8627f | ||
|
|
a53e0dce6c | ||
|
|
0ae6325353 | ||
|
|
12299120ab | ||
|
|
1a7fe172ca | ||
|
|
4f5693040e | ||
|
|
bb2df88c06 | ||
|
|
41442eb7f6 | ||
|
|
223a679ac1 | ||
|
|
3c60616b4d | ||
|
|
a01998d095 | ||
|
|
7b35162b9e | ||
|
|
c26e1a9271 | ||
|
|
9b32407744 | ||
|
|
82091b9a66 | ||
|
|
f3d9797ebe | ||
|
|
f312e1448f | ||
|
|
a11946f0ad | ||
|
|
80a8d3ef28 | ||
|
|
f4ca9d0e09 | ||
|
|
a960fa009d | ||
|
|
b96b95bc95 | ||
|
|
450641c414 | ||
|
|
94cfcdc411 | ||
|
|
150059f704 | ||
|
|
f1a8b9daee | ||
|
|
be8c0bb952 | ||
|
|
dae5b9b259 | ||
|
|
06428fac67 | ||
|
|
59b5dfc3e0 | ||
|
|
fd981a90be | ||
|
|
6b7cf3f3be | ||
|
|
469dae8c88 | ||
|
|
9d4b84ef68 | ||
|
|
4cbc802e36 | ||
|
|
5f2d07917d | ||
|
|
5c740452f6 | ||
|
|
82c2498043 | ||
|
|
0497bea264 | ||
|
|
b8e32fa459 | ||
|
|
34ebee67b7 | ||
|
|
e0c998d192 | ||
|
|
b51e9a6bdb | ||
|
|
09f396ce84 | ||
|
|
abee37eab3 | ||
|
|
42e48b2bef | ||
|
|
70ece4364c | ||
|
|
f9d5f9d52c | ||
|
|
d0ee3558d1 | ||
|
|
587297878a | ||
|
|
b4c998a9ae | ||
|
|
88e8e3977b | ||
|
|
24b86cffe9 | ||
|
|
a1773197e9 | ||
|
|
1e08d865c9 | ||
|
|
f8bb650cc1 | ||
|
|
2cee8bebb2 | ||
|
|
ade4ec5fd8 | ||
|
|
70ffd6b03f | ||
|
|
6c551df311 | ||
|
|
24f605629e | ||
|
|
2af1ec9d02 | ||
|
|
79d53341de | ||
|
|
e40b3506c4 | ||
|
|
33912382e3 | ||
|
|
d282810e53 | ||
|
|
9df502fc77 | ||
|
|
705573f0a8 | ||
|
|
1878ea94f6 | ||
|
|
4ba5086b9a | ||
|
|
4a991b4daa | ||
|
|
80474d26f9 | ||
|
|
9a77bd9140 | ||
|
|
14cdc800c3 | ||
|
|
9cfbea4c25 | ||
|
|
5fe674e223 | ||
|
|
32200efce8 | ||
|
|
68a02da990 | ||
|
|
5b20766ea3 | ||
|
|
9a914250a0 | ||
|
|
0e3106f631 | ||
|
|
6c5954f9d1 | ||
|
|
740c05a0bb | ||
|
|
26090011c4 | ||
|
|
c9ae26a176 | ||
|
|
e7db6d8120 | ||
|
|
a6af7e8824 | ||
|
|
87ba17a1f5 | ||
|
|
c7ea46a5da | ||
|
|
1439dc7712 | ||
|
|
46cac6468e | ||
|
|
2a814d886b | ||
|
|
60a2fbec41 | ||
|
|
f15a328b80 | ||
|
|
811d9ab55a | ||
|
|
e00fed5c46 | ||
|
|
a3fa38b353 | ||
|
|
2e42a4bdd9 | ||
|
|
36f72b5a49 | ||
|
|
af42d7d347 | ||
|
|
8607b1994c | ||
|
|
36eb1bd893 | ||
|
|
9fa78443de | ||
|
|
893f776f1d | ||
|
|
e051c450ed | ||
|
|
50135b726e | ||
|
|
085ab54124 | ||
|
|
8e1a56875e | ||
|
|
000626ab2e | ||
|
|
694fd0c92f | ||
|
|
c647056287 | ||
|
|
738ba40f51 | ||
|
|
3ce3a7ee72 | ||
|
|
74b43c9bdf | ||
|
|
3d2ff7755e | ||
|
|
a87d52a389 | ||
|
|
959e64c9b3 | ||
|
|
2c056ead42 | ||
|
|
30f20b55d5 | ||
|
|
1bca32ed16 | ||
|
|
7f91139e21 | ||
|
|
c53b7c7389 | ||
|
|
93f3658a4a | ||
|
|
68be95acbb | ||
|
|
813f79f0f9 | ||
|
|
c3ec86bc70 | ||
|
|
05a19753c6 | ||
|
|
a33327c651 | ||
|
|
6ad7cc4f2a | ||
|
|
c506355b8b | ||
|
|
d54168b8fb | ||
|
|
c91b071c47 | ||
|
|
9c57b18008 | ||
|
|
69539a0472 | ||
|
|
7bce455d16 | ||
|
|
3f45294c61 | ||
|
|
fd03c7eebe | ||
|
|
07c49a5726 | ||
|
|
8c688f8e29 | ||
|
|
887576d217 | ||
|
|
6652f3405b | ||
|
|
3d13167d32 | ||
|
|
f2bb507ebb | ||
|
|
fe8f3381fc | ||
|
|
2a6d11e645 | ||
|
|
01f46d3c7d | ||
|
|
5f76b62553 | ||
|
|
4bbe3b0d00 | ||
|
|
9ed86a08f1 | ||
|
|
68405910ba | ||
|
|
0a50e2638c | ||
|
|
fc7c5da4dd | ||
|
|
a3357e073c | ||
|
|
d114833a12 | ||
|
|
96038bd075 | ||
|
|
2f383c2598 | ||
|
|
702a8d1f72 | ||
|
|
0a8390356f | ||
|
|
844058c0a5 | ||
|
|
7d74cbe29c | ||
|
|
62ac0ed2dc | ||
|
|
ae14adec2a | ||
|
|
6c2b39d1df | ||
|
|
0843028e6e | ||
|
|
de0fd87035 | ||
|
|
8b6c0be259 | ||
|
|
58fec84858 | ||
|
|
f223ad7776 | ||
|
|
00eabf630d | ||
|
|
6245a27650 | ||
|
|
fa1ac57c90 | ||
|
|
0f16b1c98d | ||
|
|
08e66c5451 | ||
|
|
563bf70c95 | ||
|
|
49d29420c4 | ||
|
|
ae9d0c6c1b | ||
|
|
04f9757f8d | ||
|
|
1f9e1eb964 | ||
|
|
d8d11f9bbb | ||
|
|
13fa0d3bc0 | ||
|
|
5eeb4b8e06 | ||
|
|
f5044c290d | ||
|
|
1b43276e5d | ||
|
|
294f086857 | ||
|
|
e5024bf5e9 | ||
|
|
79198b4bba | ||
|
|
1a2f0984db | ||
|
|
454683e6eb | ||
|
|
bbb2a08e8f | ||
|
|
bf116927e1 | ||
|
|
3d249c4fa3 | ||
|
|
fa338ddb6a | ||
|
|
b200451330 | ||
|
|
8283d23b74 | ||
|
|
2fc0a4d53b | ||
|
|
3ff732d583 | ||
|
|
840c632c0a | ||
|
|
40d6e4f287 | ||
|
|
fc5f9c30a6 | ||
|
|
229de2dbb8 | ||
|
|
cc22427f25 | ||
|
|
90333c0074 | ||
|
|
54e5301b35 | ||
|
|
b31fc43bfa | ||
|
|
9bcf0b2251 | ||
|
|
d4bc98c383 | ||
|
|
bc892c535c | ||
|
|
099e1e7c08 | ||
|
|
b1000e30c1 | ||
|
|
7bd94eac0e | ||
|
|
2c77563dcc | ||
|
|
603c9a587e | ||
|
|
1a5a2dfda9 | ||
|
|
090b7eeaf3 | ||
|
|
117536324c | ||
|
|
999c092b6a | ||
|
|
9e31b1f387 | ||
|
|
cb157ea530 | ||
|
|
5f6f38074d | ||
|
|
25b8dd340a | ||
|
|
fb06f5b892 | ||
|
|
1a7fb601dc | ||
|
|
cdcfda164d | ||
|
|
966b154a1f | ||
|
|
95fa66661c | ||
|
|
6247b79111 | ||
|
|
5831364f9c | ||
|
|
919b81cff1 | ||
|
|
065fff7db5 | ||
|
|
a664ee30a2 | ||
|
|
03f3ad435a | ||
|
|
2270c270ef | ||
|
|
4f7820719b | ||
|
|
fa285883ad | ||
|
|
474fca8e6a | ||
|
|
5dc0250b00 | ||
|
|
f269377a01 | ||
|
|
d0406024e3 | ||
|
|
aa3a969bd2 | ||
|
|
73a95973a8 | ||
|
|
bf4fe3c1ac | ||
|
|
d6c08ba469 | ||
|
|
69f0ba65f1 | ||
|
|
828c86964d | ||
|
|
54b7ddd63f | ||
|
|
a0dde66b5d | ||
|
|
b6b3b9f99c | ||
|
|
faa69f8a47 | ||
|
|
d92c7f5483 | ||
|
|
6b824eb112 | ||
|
|
72b4371804 | ||
|
|
fa290aff8d | ||
|
|
3d99d7ae8b | ||
|
|
2eb367969c | ||
|
|
9cdad95f48 | ||
|
|
707ed39300 | ||
|
|
6bbb5f061a | ||
|
|
6896e69e95 | ||
|
|
b17f4c1650 | ||
|
|
98493ed9e2 | ||
|
|
94c953deab | ||
|
|
fa4d88e163 | ||
|
|
b1e1e3efc7 | ||
|
|
3b9426eb72 | ||
|
|
e2e07696fc | ||
|
|
d6a959b000 | ||
|
|
c3935d3849 | ||
|
|
383e3d77cb | ||
|
|
31e97ead2a | ||
|
|
0b49995659 | ||
|
|
ff204db6b2 | ||
|
|
f74f3d6a3a | ||
|
|
713fb061e8 | ||
|
|
77b7680b32 | ||
|
|
ff63433591 | ||
|
|
31281d7181 | ||
|
|
8285fbb0b1 | ||
|
|
951e6b746c | ||
|
|
44a6623094 | ||
|
|
72d1e4e404 | ||
|
|
91918e648b | ||
|
|
1390b65a9c | ||
|
|
82231369d3 | ||
|
|
7620bacc01 | ||
|
|
ea9cf04765 | ||
|
|
47301e6f85 | ||
|
|
f143fb7254 | ||
|
|
2bdb655375 | ||
|
|
41f7758977 | ||
|
|
8ae1eaaccc | ||
|
|
98773b20ac | ||
|
|
d66979073b | ||
|
|
c9e621093e | ||
|
|
e06ba40795 | ||
|
|
6571e4c2fd | ||
|
|
ff9240b51d | ||
|
|
18466e01fd | ||
|
|
e9821ab711 | ||
|
|
d6530df635 | ||
|
|
b47786e846 | ||
|
|
062b2cf46f | ||
|
|
082ecf6f25 | ||
|
|
1632ac6b9f | ||
|
|
69ccd3a0b5 | ||
|
|
877959b413 | ||
|
|
6e60f7517b | ||
|
|
296ee6b7ea | ||
|
|
7c7ffddb2b | ||
|
|
e1ae7842ff | ||
|
|
9687fe7bac | ||
|
|
a9a2bd90c2 | ||
|
|
47ca71a7eb | ||
|
|
a9c47237b1 | ||
|
|
33bbae2f47 | ||
|
|
fab7a1d337 | ||
|
|
cffcf80977 | ||
|
|
1a3fd05b81 | ||
|
|
c22c6ca135 | ||
|
|
3afb6a387f | ||
|
|
33e5ed7180 | ||
|
|
2067757fab | ||
|
|
b1b94a3d56 | ||
|
|
c9ee42450e | ||
|
|
10fe31c2a1 | ||
|
|
420a76ecdd | ||
|
|
79de9047b5 | ||
|
|
dc54cbb1fc | ||
|
|
070218aba7 | ||
|
|
f1c226b171 | ||
|
|
7004430380 | ||
|
|
1ddc620192 | ||
|
|
a7cebbd970 | ||
|
|
d97438b0b3 | ||
|
|
4522f3f4c9 | ||
|
|
6fe28980b0 | ||
|
|
4aec5d8ffc | ||
|
|
bbb4e8f5ef | ||
|
|
bce33ea62e | ||
|
|
e4705d5ce7 | ||
|
|
6764b2a854 | ||
|
|
970340cf62 | ||
|
|
043f9d9ba4 | ||
|
|
6f82801d07 | ||
|
|
3e3dd39ae4 | ||
|
|
89aa06e014 | ||
|
|
6cc00ef4b7 | ||
|
|
f31e62afad | ||
|
|
38fd2ad45d | ||
|
|
05b99b5377 | ||
|
|
08a14ee6d5 | ||
|
|
29fcc92da9 | ||
|
|
d78e3572e3 | ||
|
|
160267c71a | ||
|
|
fd47e70c92 | ||
|
|
9317b42e5f | ||
|
|
bdab73701f | ||
|
|
3ea5e78322 | ||
|
|
f609ee21a2 | ||
|
|
f51defeeb3 | ||
|
|
ee0225f4ba | ||
|
|
33a0af4637 | ||
|
|
d37b08a7dd | ||
|
|
9a796364da | ||
|
|
1ad4eb3a7b | ||
|
|
3767a453bb | ||
|
|
b0892d30a4 | ||
|
|
d9b1e4a98c | ||
|
|
a4dec8c1d6 | ||
|
|
8960ceb98b | ||
|
|
be79d088c0 | ||
|
|
009407ea3f | ||
|
|
6999d28c7f | ||
|
|
324e9eb74b | ||
|
|
56cff40362 | ||
|
|
2ba40c5e52 | ||
|
|
3ab147204c | ||
|
|
e4c89cba9c | ||
|
|
322ea84c4e | ||
|
|
f2b41c60ff | ||
|
|
754acec92f | ||
|
|
11fc7e40a5 | ||
|
|
d15bb88eb2 | ||
|
|
70ba36eefc | ||
|
|
7e70391c2b | ||
|
|
e2a94be336 | ||
|
|
63a86eefb4 | ||
|
|
b0727b9d47 | ||
|
|
d96e727dd5 | ||
|
|
fe480886dc | ||
|
|
8031d1827b | ||
|
|
b5acdb322d | ||
|
|
a4d1fe8819 | ||
|
|
10b7a58887 | ||
|
|
901a277959 | ||
|
|
aaa093bef1 | ||
|
|
bb96543d66 | ||
|
|
a2a2cfa765 | ||
|
|
18e6a2b410 | ||
|
|
db27263bc2 | ||
|
|
0e027ec3ef | ||
|
|
5acbbeecaa | ||
|
|
6ef2168b67 | ||
|
|
6d958a214c | ||
|
|
4ae4bf4ff9 | ||
|
|
fdef53b2de | ||
|
|
11bd038b9d | ||
|
|
768cfe3aab | ||
|
|
c4277b0662 | ||
|
|
020f3ccf07 | ||
|
|
7467fa5e57 | ||
|
|
e19ef7ed2f | ||
|
|
71003be6b8 | ||
|
|
c1dbafc2df | ||
|
|
dcebd71381 | ||
|
|
d855a65e73 | ||
|
|
a9007c7e0f | ||
|
|
af60304f97 | ||
|
|
6de241eead | ||
|
|
51032dc0b2 | ||
|
|
9ec3d2bc0c | ||
|
|
297931f5d9 | ||
|
|
f613c073c1 | ||
|
|
63d248622c | ||
|
|
48485fe92f | ||
|
|
07726af703 | ||
|
|
ad1004b485 | ||
|
|
0096fb2790 | ||
|
|
9c8c2e49d6 | ||
|
|
2005a96847 | ||
|
|
00a8d60c1b | ||
|
|
3aa182390a | ||
|
|
e44f1d6d4e | ||
|
|
dfdf8e2ead | ||
|
|
3a645c4e80 | ||
|
|
113129daf9 | ||
|
|
940e3b6635 | ||
|
|
7fb29dabff | ||
|
|
714ad6dbb8 | ||
|
|
c0863fa20f | ||
|
|
78b0b37ba6 | ||
|
|
5d5cdc7716 | ||
|
|
93cd818f6a | ||
|
|
598a628790 | ||
|
|
f3666eda63 | ||
|
|
754017b59e | ||
|
|
21251ce12c | ||
|
|
dc12fa6cd6 | ||
|
|
f2f4c37f19 | ||
|
|
0864fca641 | ||
|
|
5e4c0217c7 | ||
|
|
78cd106c23 | ||
|
|
6ed0efa938 | ||
|
|
ca0669c337 | ||
|
|
b59a749627 | ||
|
|
a91dee87d0 | ||
|
|
5ff98a4179 | ||
|
|
36b2f12219 | ||
|
|
5569f205ee | ||
|
|
a76cf8aab2 | ||
|
|
5c0f0d1808 | ||
|
|
951900a86a | ||
|
|
582f516fef | ||
|
|
a25bae2545 | ||
|
|
0ea35b1e3d | ||
|
|
c6f935bf1a | ||
|
|
96b4d35d43 | ||
|
|
7b0938e7e4 | ||
|
|
249522b568 | ||
|
|
39088e42cc | ||
|
|
30e0033ebe | ||
|
|
b599c40099 | ||
|
|
8f190169db | ||
|
|
1d4d705795 | ||
|
|
b3f71b3078 | ||
|
|
6059db4f15 | ||
|
|
0d5f44b153 | ||
|
|
17164a37a8 | ||
|
|
f88ccabe30 | ||
|
|
e1c85f1234 | ||
|
|
f50293920e | ||
|
|
1e2db3a17f | ||
|
|
57a3eb3652 | ||
|
|
82a8972bde | ||
|
|
497a885c85 | ||
|
|
4d9f55d0f6 | ||
|
|
5f8f51436a | ||
|
|
0c3b4bb70d | ||
|
|
33e13820fc | ||
|
|
43d991cfdb | ||
|
|
291e9cf14b | ||
|
|
a2de5c9963 | ||
|
|
5025f84627 | ||
|
|
d2c8a53c55 | ||
|
|
5659d10778 | ||
|
|
46cab81d6f | ||
|
|
dd157bce85 | ||
|
|
2f25dd7d0d | ||
|
|
e56965ad76 | ||
|
|
2273b3a8c8 | ||
|
|
05fb0ac2b2 | ||
|
|
d4acd49ee3 | ||
|
|
d98868e524 | ||
|
|
93bb27f2c7 | ||
|
|
a4c44edf8d | ||
|
|
1e94d7739a | ||
|
|
9110838fe4 | ||
|
|
ca7b267326 | ||
|
|
7f5992d6a5 | ||
|
|
88776fb2de | ||
|
|
34f567abd4 | ||
|
|
b87f3043ae | ||
|
|
3829ffbe66 | ||
|
|
ad619ae880 | ||
|
|
d22ebe08be | ||
|
|
ee0c6ad86e | ||
|
|
96adb56633 | ||
|
|
3000436121 | ||
|
|
37cdd91f5d | ||
|
|
6f3c6ddf3f | ||
|
|
0bfbda512d | ||
|
|
295b98a13c | ||
|
|
ff6b345d45 | ||
|
|
1fb307abf4 | ||
|
|
29c952dcf6 | ||
|
|
010f63a50d | ||
|
|
068bbe3a39 | ||
|
|
ad39680feb | ||
|
|
1e0ae8404c | ||
|
|
460d555a3d | ||
|
|
66ad04fcfc | ||
|
|
c7c0836721 | ||
|
|
d2c223de8f | ||
|
|
dd16f788ed | ||
|
|
b25c1af018 | ||
|
|
8f393b64b8 | ||
|
|
55b3193629 | ||
|
|
6f78c073ed | ||
|
|
c406be6f4f | ||
|
|
aeaf3737aa | ||
|
|
23d9d58c08 | ||
|
|
4c331a5d7e | ||
|
|
035425ef24 | ||
|
|
021e5a2aa3 | ||
|
|
7a1de3887e | ||
|
|
4a7a5234df | ||
|
|
6aebe1614d | ||
|
|
74292eba28 | ||
|
|
c31ff364ab | ||
|
|
f310a39381 | ||
|
|
5a7e611e0a | ||
|
|
4e29a751d8 | ||
|
|
3f94f81acd | ||
|
|
5de3c41d19 | ||
|
|
f071b03ceb | ||
|
|
b9375186a5 | ||
|
|
11bd932cba | ||
|
|
b77ccfaf32 | ||
|
|
96653eebb6 | ||
|
|
60d25f105f | ||
|
|
734b653a5f | ||
|
|
52c9e6ec91 | ||
|
|
c0f132e41a | ||
|
|
cc1160a43a | ||
|
|
adde8450bc | ||
|
|
5bf9891553 | ||
|
|
22c34c343a | ||
|
|
f7804f6126 | ||
|
|
d14b02e93f | ||
|
|
1b75d899ae | ||
|
|
d4aa79acd7 | ||
|
|
33d199c007 | ||
|
|
9c89d3452c | ||
|
|
fb0b63c580 | ||
|
|
bb2c6e5925 | ||
|
|
928caff2a6 | ||
|
|
670c79f2c7 | ||
|
|
d6efb98953 | ||
|
|
19da795274 | ||
|
|
454ba9b893 | ||
|
|
8e419a4f97 | ||
|
|
2533209326 | ||
|
|
d2dc1ed26f | ||
|
|
d4fb16825e | ||
|
|
165c1adcf8 | ||
|
|
650d69ef5b | ||
|
|
ff0e79fa9a | ||
|
|
127b54f812 | ||
|
|
bdf33f13b3 | ||
|
|
27241cdde1 | ||
|
|
259d6ec90d | ||
|
|
a77c4c87b2 | ||
|
|
d96175d127 | ||
|
|
7025c00581 | ||
|
|
b1a99d772c | ||
|
|
7ea995149e | ||
|
|
fd82763412 | ||
|
|
f9710dd6ed | ||
|
|
4e7dd7d3f6 | ||
|
|
20ca9e1fc1 | ||
|
|
8a8b09a953 | ||
|
|
9e4e386c9b | ||
|
|
eca1e449a8 | ||
|
|
ffaadb9d05 | ||
|
|
8adff96e29 | ||
|
|
7593dc19d6 | ||
|
|
b7c5a39685 | ||
|
|
bd1b84f7d0 | ||
|
|
eadfd239a8 | ||
|
|
e971a7f35c | ||
|
|
8d75e50435 | ||
|
|
6ab84741a0 | ||
|
|
cd16857f38 | ||
|
|
1442f1cb8d | ||
|
|
eea0d6f7bc | ||
|
|
1d9c115225 | ||
|
|
4fe94a9315 | ||
|
|
30af20a056 | ||
|
|
7ef0d2aa35 | ||
|
|
c8f765cc06 | ||
|
|
b9e9087dbe | ||
|
|
63e465eb5c | ||
|
|
426f4eaf7e | ||
|
|
baf5451fa0 | ||
|
|
1103ab2844 | ||
|
|
11b2076b46 | ||
|
|
b31a6ff605 | ||
|
|
1f602e6143 | ||
|
|
039fa73269 | ||
|
|
2204e47596 | ||
|
|
d8b1f29066 | ||
|
|
b23c9f1da5 | ||
|
|
5e8e3cf464 | ||
|
|
72967bf118 | ||
|
|
bc96727cbe | ||
|
|
3b2a054f7a | ||
|
|
131145eab1 | ||
|
|
4492044d29 | ||
|
|
5431dd5f50 | ||
|
|
79fecba274 | ||
|
|
2ef79b8bf3 | ||
|
|
11ecf438f5 | ||
|
|
df5b968954 | ||
|
|
8ad8c5c67a | ||
|
|
590942edd7 | ||
|
|
4627910c5d | ||
|
|
fa6a580452 | ||
|
|
99c692f397 | ||
|
|
3d85e769ce | ||
|
|
9cb962cad7 | ||
|
|
a108155544 | ||
|
|
c15b49c805 | ||
|
|
fd63e36822 | ||
|
|
4649920074 | ||
|
|
667171ed90 | ||
|
|
f28632980d | ||
|
|
42d938fda5 | ||
|
|
25ce47c44f | ||
|
|
647ffb2a0f | ||
|
|
afd2e32092 | ||
|
|
05a27bda5e | ||
|
|
a8cfa3565c | ||
|
|
e0214a32bc | ||
|
|
af8c7c7d29 | ||
|
|
a4e36bc02a | ||
|
|
2e9bec15e7 | ||
|
|
68bc0112fa | ||
|
|
742ed19d66 | ||
|
|
29c2ada23c | ||
|
|
e4196bbe5b | ||
|
|
15ffb53e59 | ||
|
|
90054ddf0d | ||
|
|
a273bdbdc1 | ||
|
|
8a0ec0fa0f | ||
|
|
e1fed52c66 | ||
|
|
bb959448c1 | ||
|
|
2e2abf6ea6 | ||
|
|
956ad6bcf5 |
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@@ -2,7 +2,7 @@
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @tildebyte @blessedcoolant
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
@@ -18,17 +18,17 @@
|
||||
/invokeai/version @lstein @blessedcoolant
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant
|
||||
/invokeai/frontend/training @lstein @blessedcoolant
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
|
||||
|
||||
|
||||
|
||||
19
.github/workflows/test-invoke-pip.yml
vendored
19
.github/workflows/test-invoke-pip.yml
vendored
@@ -80,11 +80,6 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: set test prompt to Pull Request validation
|
||||
if: ${{ github.ref != 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
@@ -105,12 +100,6 @@ jobs:
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
- name: set INVOKEAI_OUTDIR
|
||||
run: >
|
||||
python -c
|
||||
"import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
|
||||
>> ${{ matrix.github-env }}
|
||||
|
||||
- name: run invokeai-configure
|
||||
id: run-preload-models
|
||||
env:
|
||||
@@ -129,15 +118,21 @@ jobs:
|
||||
HF_HUB_OFFLINE: 1
|
||||
HF_DATASETS_OFFLINE: 1
|
||||
TRANSFORMERS_OFFLINE: 1
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
run: >
|
||||
invokeai
|
||||
--no-patchmatch
|
||||
--no-nsfw_checker
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
--precision=float32
|
||||
--always_use_cpu
|
||||
--use_memory_db
|
||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
|
||||
- name: Archive results
|
||||
id: archive-results
|
||||
env:
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -201,6 +201,8 @@ checkpoints
|
||||
# If it's a Mac
|
||||
.DS_Store
|
||||
|
||||
invokeai/frontend/web/dist/*
|
||||
|
||||
# Let the frontend manage its own gitignore
|
||||
!invokeai/frontend/web/*
|
||||
|
||||
|
||||
17
README.md
17
README.md
@@ -43,6 +43,23 @@ _Note: InvokeAI is rapidly evolving. Please use the
|
||||
[Issues](https://github.com/invoke-ai/InvokeAI/issues) tab to report bugs and make feature
|
||||
requests. Be sure to use the provided templates. They will help us diagnose issues faster._
|
||||
|
||||
## FOR DEVELOPERS - MIGRATING TO THE 3.0.0 MODELS FORMAT
|
||||
|
||||
The models directory and models.yaml have changed. To migrate to the
|
||||
new layout, please follow this recipe:
|
||||
|
||||
1. Run `python scripts/migrate_models_to_3.0.py <path_to_root_directory>
|
||||
|
||||
2. This will create a new models directory named `models-3.0` and a
|
||||
new config directory named `models.yaml-3.0`, both in the current
|
||||
working directory. If you prefer to name them something else, pass
|
||||
the `--dest-directory` and/or `--dest-yaml` arguments.
|
||||
|
||||
3. Check that the new models directory and yaml file look ok.
|
||||
|
||||
4. Replace the existing directory and file, keeping backup copies just in
|
||||
case.
|
||||
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
Binary file not shown.
@@ -1,164 +0,0 @@
|
||||
@echo off
|
||||
|
||||
@rem This script will install git (if not found on the PATH variable)
|
||||
@rem using micromamba (an 8mb static-linked single-file binary, conda replacement).
|
||||
@rem For users who already have git, this step will be skipped.
|
||||
|
||||
@rem Next, it'll download the project's source code.
|
||||
@rem Then it will download a self-contained, standalone Python and unpack it.
|
||||
@rem Finally, it'll create the Python virtual environment and preload the models.
|
||||
|
||||
@rem This enables a user to install this project without manually installing git or Python
|
||||
|
||||
@rem change to the script's directory
|
||||
PUSHD "%~dp0"
|
||||
|
||||
set "no_cache_dir=--no-cache-dir"
|
||||
if "%1" == "use-cache" (
|
||||
set "no_cache_dir="
|
||||
)
|
||||
|
||||
echo ***** Installing InvokeAI.. *****
|
||||
@rem Config
|
||||
set INSTALL_ENV_DIR=%cd%\installer_files\env
|
||||
@rem https://mamba.readthedocs.io/en/latest/installation.html
|
||||
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
|
||||
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||
set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
|
||||
|
||||
set PACKAGES_TO_INSTALL=
|
||||
|
||||
call git --version >.tmp1 2>.tmp2
|
||||
if "%ERRORLEVEL%" NEQ "0" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% git
|
||||
|
||||
@rem Cleanup
|
||||
del /q .tmp1 .tmp2
|
||||
|
||||
@rem (if necessary) install git into a contained environment
|
||||
if "%PACKAGES_TO_INSTALL%" NEQ "" (
|
||||
@rem download micromamba
|
||||
echo ***** Downloading micromamba from %MICROMAMBA_DOWNLOAD_URL% to micromamba.exe *****
|
||||
|
||||
call curl -L "%MICROMAMBA_DOWNLOAD_URL%" > micromamba.exe
|
||||
|
||||
@rem test the mamba binary
|
||||
echo ***** Micromamba version: *****
|
||||
call micromamba.exe --version
|
||||
|
||||
@rem create the installer env
|
||||
if not exist "%INSTALL_ENV_DIR%" (
|
||||
call micromamba.exe create -y --prefix "%INSTALL_ENV_DIR%"
|
||||
)
|
||||
|
||||
echo ***** Packages to install:%PACKAGES_TO_INSTALL% *****
|
||||
|
||||
call micromamba.exe install -y --prefix "%INSTALL_ENV_DIR%" -c conda-forge %PACKAGES_TO_INSTALL%
|
||||
|
||||
if not exist "%INSTALL_ENV_DIR%" (
|
||||
echo ----- There was a problem while installing "%PACKAGES_TO_INSTALL%" using micromamba. Cannot continue. -----
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
)
|
||||
|
||||
del /q micromamba.exe
|
||||
|
||||
@rem For 'git' only
|
||||
set PATH=%INSTALL_ENV_DIR%\Library\bin;%PATH%
|
||||
|
||||
@rem Download/unpack/clean up InvokeAI release sourceball
|
||||
set err_msg=----- InvokeAI source download failed -----
|
||||
echo Trying to download "%RELEASE_URL%%RELEASE_SOURCEBALL%"
|
||||
curl -L %RELEASE_URL%%RELEASE_SOURCEBALL% --output InvokeAI.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- InvokeAI source unpack failed -----
|
||||
tar -zxf InvokeAI.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
del /q InvokeAI.tgz
|
||||
|
||||
set err_msg=----- InvokeAI source copy failed -----
|
||||
cd InvokeAI-*
|
||||
xcopy . .. /e /h
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
cd ..
|
||||
|
||||
@rem cleanup
|
||||
for /f %%i in ('dir /b InvokeAI-*') do rd /s /q %%i
|
||||
rd /s /q .dev_scripts .github docker-build tests
|
||||
del /q requirements.in requirements-mkdocs.txt shell.nix
|
||||
|
||||
echo ***** Unpacked InvokeAI source *****
|
||||
|
||||
@rem Download/unpack/clean up python-build-standalone
|
||||
set err_msg=----- Python download failed -----
|
||||
curl -L %PYTHON_BUILD_STANDALONE_URL%/%PYTHON_BUILD_STANDALONE% --output python.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- Python unpack failed -----
|
||||
tar -zxf python.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
del /q python.tgz
|
||||
|
||||
echo ***** Unpacked python-build-standalone *****
|
||||
|
||||
@rem create venv
|
||||
set err_msg=----- problem creating venv -----
|
||||
.\python\python -E -s -m venv .venv
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
call .venv\Scripts\activate.bat
|
||||
|
||||
echo ***** Created Python virtual environment *****
|
||||
|
||||
@rem Print venv's Python version
|
||||
set err_msg=----- problem calling venv's python -----
|
||||
echo We're running under
|
||||
.venv\Scripts\python --version
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- pip update failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location --upgrade pip wheel
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
echo ***** Updated pip and wheel *****
|
||||
|
||||
set err_msg=----- requirements file copy failed -----
|
||||
copy binary_installer\py3.10-windows-x86_64-cuda-reqs.txt requirements.txt
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- main pip install failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -r requirements.txt
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
echo ***** Installed Python dependencies *****
|
||||
|
||||
set err_msg=----- InvokeAI setup failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -e .
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
copy binary_installer\invoke.bat.in .\invoke.bat
|
||||
echo ***** Installed invoke launcher script ******
|
||||
|
||||
@rem more cleanup
|
||||
rd /s /q binary_installer installer_files
|
||||
|
||||
@rem preload the models
|
||||
call .venv\Scripts\python ldm\invoke\config\invokeai_configure.py
|
||||
set err_msg=----- model download clone failed -----
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
deactivate
|
||||
|
||||
echo ***** Finished downloading models *****
|
||||
|
||||
echo All done! Execute the file invoke.bat in this directory to start InvokeAI
|
||||
pause
|
||||
exit
|
||||
|
||||
:err_exit
|
||||
echo %err_msg%
|
||||
pause
|
||||
exit
|
||||
@@ -1,235 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
set -euo pipefail
|
||||
IFS=$'\n\t'
|
||||
|
||||
function _err_exit {
|
||||
if test "$1" -ne 0
|
||||
then
|
||||
echo -e "Error code $1; Error caught was '$2'"
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
fi
|
||||
}
|
||||
|
||||
# This script will install git (if not found on the PATH variable)
|
||||
# using micromamba (an 8mb static-linked single-file binary, conda replacement).
|
||||
# For users who already have git, this step will be skipped.
|
||||
|
||||
# Next, it'll download the project's source code.
|
||||
# Then it will download a self-contained, standalone Python and unpack it.
|
||||
# Finally, it'll create the Python virtual environment and preload the models.
|
||||
|
||||
# This enables a user to install this project without manually installing git or Python
|
||||
|
||||
echo -e "\n***** Installing InvokeAI into $(pwd)... *****\n"
|
||||
|
||||
export no_cache_dir="--no-cache-dir"
|
||||
if [ $# -ge 1 ]; then
|
||||
if [ "$1" = "use-cache" ]; then
|
||||
export no_cache_dir=""
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
OS_NAME=$(uname -s)
|
||||
case "${OS_NAME}" in
|
||||
Linux*) OS_NAME="linux";;
|
||||
Darwin*) OS_NAME="darwin";;
|
||||
*) echo -e "\n----- Unknown OS: $OS_NAME! This script runs only on Linux or macOS -----\n" && exit
|
||||
esac
|
||||
|
||||
OS_ARCH=$(uname -m)
|
||||
case "${OS_ARCH}" in
|
||||
x86_64*) ;;
|
||||
arm64*) ;;
|
||||
*) echo -e "\n----- Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64 -----\n" && exit
|
||||
esac
|
||||
|
||||
# https://mamba.readthedocs.io/en/latest/installation.html
|
||||
MAMBA_OS_NAME=$OS_NAME
|
||||
MAMBA_ARCH=$OS_ARCH
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
MAMBA_OS_NAME="osx"
|
||||
fi
|
||||
|
||||
if [ "$OS_ARCH" == "linux" ]; then
|
||||
MAMBA_ARCH="aarch64"
|
||||
fi
|
||||
|
||||
if [ "$OS_ARCH" == "x86_64" ]; then
|
||||
MAMBA_ARCH="64"
|
||||
fi
|
||||
|
||||
PY_ARCH=$OS_ARCH
|
||||
if [ "$OS_ARCH" == "arm64" ]; then
|
||||
PY_ARCH="aarch64"
|
||||
fi
|
||||
|
||||
# Compute device ('cd' segment of reqs files) detect goes here
|
||||
# This needs a ton of work
|
||||
# Suggestions:
|
||||
# - lspci
|
||||
# - check $PATH for nvidia-smi, gtt CUDA/GPU version from output
|
||||
# - Surely there's a similar utility for AMD?
|
||||
CD="cuda"
|
||||
if [ "$OS_NAME" == "darwin" ] && [ "$OS_ARCH" == "arm64" ]; then
|
||||
CD="mps"
|
||||
fi
|
||||
|
||||
# config
|
||||
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
|
||||
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
|
||||
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||
RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
|
||||
elif [ "$OS_NAME" == "linux" ]; then
|
||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-unknown-linux-gnu-install_only.tar.gz
|
||||
fi
|
||||
echo "INSTALLING $RELEASE_SOURCEBALL FROM $RELEASE_URL"
|
||||
|
||||
PACKAGES_TO_INSTALL=""
|
||||
|
||||
if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
|
||||
|
||||
# (if necessary) install git and conda into a contained environment
|
||||
if [ "$PACKAGES_TO_INSTALL" != "" ]; then
|
||||
# download micromamba
|
||||
echo -e "\n***** Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to micromamba *****\n"
|
||||
|
||||
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > micromamba
|
||||
|
||||
chmod u+x ./micromamba
|
||||
|
||||
# test the mamba binary
|
||||
echo -e "\n***** Micromamba version: *****\n"
|
||||
./micromamba --version
|
||||
|
||||
# create the installer env
|
||||
if [ ! -e "$INSTALL_ENV_DIR" ]; then
|
||||
./micromamba create -y --prefix "$INSTALL_ENV_DIR"
|
||||
fi
|
||||
|
||||
echo -e "\n***** Packages to install:$PACKAGES_TO_INSTALL *****\n"
|
||||
|
||||
./micromamba install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge "$PACKAGES_TO_INSTALL"
|
||||
|
||||
if [ ! -e "$INSTALL_ENV_DIR" ]; then
|
||||
echo -e "\n----- There was a problem while initializing micromamba. Cannot continue. -----\n"
|
||||
exit
|
||||
fi
|
||||
fi
|
||||
|
||||
rm -f micromamba.exe
|
||||
|
||||
export PATH="$INSTALL_ENV_DIR/bin:$PATH"
|
||||
|
||||
# Download/unpack/clean up InvokeAI release sourceball
|
||||
_err_msg="\n----- InvokeAI source download failed -----\n"
|
||||
curl -L $RELEASE_URL/$RELEASE_SOURCEBALL --output InvokeAI.tgz
|
||||
_err_exit $? _err_msg
|
||||
_err_msg="\n----- InvokeAI source unpack failed -----\n"
|
||||
tar -zxf InvokeAI.tgz
|
||||
_err_exit $? _err_msg
|
||||
|
||||
rm -f InvokeAI.tgz
|
||||
|
||||
_err_msg="\n----- InvokeAI source copy failed -----\n"
|
||||
cd InvokeAI-*
|
||||
cp -r . ..
|
||||
_err_exit $? _err_msg
|
||||
cd ..
|
||||
|
||||
# cleanup
|
||||
rm -rf InvokeAI-*/
|
||||
rm -rf .dev_scripts/ .github/ docker-build/ tests/ requirements.in requirements-mkdocs.txt shell.nix
|
||||
|
||||
echo -e "\n***** Unpacked InvokeAI source *****\n"
|
||||
|
||||
# Download/unpack/clean up python-build-standalone
|
||||
_err_msg="\n----- Python download failed -----\n"
|
||||
curl -L $PYTHON_BUILD_STANDALONE_URL/$PYTHON_BUILD_STANDALONE --output python.tgz
|
||||
_err_exit $? _err_msg
|
||||
_err_msg="\n----- Python unpack failed -----\n"
|
||||
tar -zxf python.tgz
|
||||
_err_exit $? _err_msg
|
||||
|
||||
rm -f python.tgz
|
||||
|
||||
echo -e "\n***** Unpacked python-build-standalone *****\n"
|
||||
|
||||
# create venv
|
||||
_err_msg="\n----- problem creating venv -----\n"
|
||||
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
# patch sysconfig so that extensions can build properly
|
||||
# adapted from https://github.com/cashapp/hermit-packages/commit/fcba384663892f4d9cfb35e8639ff7a28166ee43
|
||||
PYTHON_INSTALL_DIR="$(pwd)/python"
|
||||
SYSCONFIG="$(echo python/lib/python*/_sysconfigdata_*.py)"
|
||||
TMPFILE="$(mktemp)"
|
||||
chmod +w "${SYSCONFIG}"
|
||||
cp "${SYSCONFIG}" "${TMPFILE}"
|
||||
sed "s,'/install,'${PYTHON_INSTALL_DIR},g" "${TMPFILE}" > "${SYSCONFIG}"
|
||||
rm -f "${TMPFILE}"
|
||||
fi
|
||||
|
||||
./python/bin/python3 -E -s -m venv .venv
|
||||
_err_exit $? _err_msg
|
||||
source .venv/bin/activate
|
||||
|
||||
echo -e "\n***** Created Python virtual environment *****\n"
|
||||
|
||||
# Print venv's Python version
|
||||
_err_msg="\n----- problem calling venv's python -----\n"
|
||||
echo -e "We're running under"
|
||||
.venv/bin/python3 --version
|
||||
_err_exit $? _err_msg
|
||||
|
||||
_err_msg="\n----- pip update failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location --upgrade pip
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Updated pip *****\n"
|
||||
|
||||
_err_msg="\n----- requirements file copy failed -----\n"
|
||||
cp binary_installer/py3.10-${OS_NAME}-"${OS_ARCH}"-${CD}-reqs.txt requirements.txt
|
||||
_err_exit $? _err_msg
|
||||
|
||||
_err_msg="\n----- main pip install failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -r requirements.txt
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Installed Python dependencies *****\n"
|
||||
|
||||
_err_msg="\n----- InvokeAI setup failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -e .
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Installed InvokeAI *****\n"
|
||||
|
||||
cp binary_installer/invoke.sh.in ./invoke.sh
|
||||
chmod a+rx ./invoke.sh
|
||||
echo -e "\n***** Installed invoke launcher script ******\n"
|
||||
|
||||
# more cleanup
|
||||
rm -rf binary_installer/ installer_files/
|
||||
|
||||
# preload the models
|
||||
.venv/bin/python3 scripts/configure_invokeai.py
|
||||
_err_msg="\n----- model download clone failed -----\n"
|
||||
_err_exit $? _err_msg
|
||||
deactivate
|
||||
|
||||
echo -e "\n***** Finished downloading models *****\n"
|
||||
|
||||
echo "All done! Run the command"
|
||||
echo " $scriptdir/invoke.sh"
|
||||
echo "to start InvokeAI."
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
@@ -1,36 +0,0 @@
|
||||
@echo off
|
||||
|
||||
PUSHD "%~dp0"
|
||||
call .venv\Scripts\activate.bat
|
||||
|
||||
echo Do you want to generate images using the
|
||||
echo 1. command-line
|
||||
echo 2. browser-based UI
|
||||
echo OR
|
||||
echo 3. open the developer console
|
||||
set /p choice="Please enter 1, 2 or 3: "
|
||||
if /i "%choice%" == "1" (
|
||||
echo Starting the InvokeAI command-line.
|
||||
.venv\Scripts\python scripts\invoke.py %*
|
||||
) else if /i "%choice%" == "2" (
|
||||
echo Starting the InvokeAI browser-based UI.
|
||||
.venv\Scripts\python scripts\invoke.py --web %*
|
||||
) else if /i "%choice%" == "3" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
echo Python version is:
|
||||
python --version
|
||||
echo *************************
|
||||
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
|
||||
echo so that you can troubleshoot this InvokeAI installation as necessary.
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) else (
|
||||
echo Invalid selection
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
deactivate
|
||||
@@ -1,46 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
set -eu
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
# set required env var for torch on mac MPS
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line"
|
||||
echo "2. browser-based UI"
|
||||
echo "OR"
|
||||
echo "3. open the developer console"
|
||||
echo "Please enter 1, 2, or 3:"
|
||||
read choice
|
||||
|
||||
case $choice in
|
||||
1)
|
||||
printf "\nStarting the InvokeAI command-line..\n";
|
||||
.venv/bin/python scripts/invoke.py $*;
|
||||
;;
|
||||
2)
|
||||
printf "\nStarting the InvokeAI browser-based UI..\n";
|
||||
.venv/bin/python scripts/invoke.py --web $*;
|
||||
;;
|
||||
3)
|
||||
printf "\nDeveloper Console:\n";
|
||||
printf "Python command is:\n\t";
|
||||
which python;
|
||||
printf "Python version is:\n\t";
|
||||
python --version;
|
||||
echo "*************************"
|
||||
echo "You are now in your user shell ($SHELL) with the local InvokeAI Python virtual environment activated,";
|
||||
echo "so that you can troubleshoot this InvokeAI installation as necessary.";
|
||||
printf "*************************\n"
|
||||
echo "*** Type \`exit\` to quit this shell and deactivate the Python virtual environment *** ";
|
||||
/usr/bin/env "$SHELL";
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection";
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,17 +0,0 @@
|
||||
InvokeAI
|
||||
|
||||
Project homepage: https://github.com/invoke-ai/InvokeAI
|
||||
|
||||
Installation on Windows:
|
||||
NOTE: You might need to enable Windows Long Paths. If you're not sure,
|
||||
then you almost certainly need to. Simply double-click the 'WinLongPathsEnabled.reg'
|
||||
file. Note that you will need to have admin privileges in order to
|
||||
do this.
|
||||
|
||||
Please double-click the 'install.bat' file (while keeping it inside the invokeAI folder).
|
||||
|
||||
Installation on Linux and Mac:
|
||||
Please open the terminal, and run './install.sh' (while keeping it inside the invokeAI folder).
|
||||
|
||||
After installation, please run the 'invoke.bat' file (on Windows) or 'invoke.sh'
|
||||
file (on Linux/Mac) to start InvokeAI.
|
||||
@@ -1,33 +0,0 @@
|
||||
--prefer-binary
|
||||
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
|
||||
--extra-index-url https://download.pytorch.org/whl/cu116
|
||||
--trusted-host https://download.pytorch.org
|
||||
accelerate~=0.15
|
||||
albumentations
|
||||
diffusers[torch]~=0.11
|
||||
einops
|
||||
eventlet
|
||||
flask_cors
|
||||
flask_socketio
|
||||
flaskwebgui==1.0.3
|
||||
getpass_asterisk
|
||||
imageio-ffmpeg
|
||||
pyreadline3
|
||||
realesrgan
|
||||
send2trash
|
||||
streamlit
|
||||
taming-transformers-rom1504
|
||||
test-tube
|
||||
torch-fidelity
|
||||
torch==1.12.1 ; platform_system == 'Darwin'
|
||||
torch==1.12.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
|
||||
torchvision==0.13.1 ; platform_system == 'Darwin'
|
||||
torchvision==0.13.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
|
||||
transformers
|
||||
picklescan
|
||||
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
|
||||
https://github.com/invoke-ai/clipseg/archive/1f754751c85d7d4255fa681f4491ff5711c1c288.zip
|
||||
https://github.com/invoke-ai/GFPGAN/archive/3f5d2397361199bc4a91c08bb7d80f04d7805615.zip ; platform_system=='Windows'
|
||||
https://github.com/invoke-ai/GFPGAN/archive/c796277a1cf77954e5fc0b288d7062d162894248.zip ; platform_system=='Linux' or platform_system=='Darwin'
|
||||
https://github.com/Birch-san/k-diffusion/archive/363386981fee88620709cf8f6f2eea167bd6cd74.zip
|
||||
https://github.com/invoke-ai/PyPatchMatch/archive/129863937a8ab37f6bbcec327c994c0f932abdbc.zip
|
||||
@@ -19,31 +19,56 @@ An invocation looks like this:
|
||||
```py
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
|
||||
# fmt: off
|
||||
type: Literal["upscale"] = "upscale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description = "The upscale level")
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
strength=0.0, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
Each portion is important to implement correctly.
|
||||
@@ -95,25 +120,67 @@ Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
If the `name` also matches, then the field can be **automatically linked** to a
|
||||
previous invocation by name and matching.
|
||||
|
||||
### Config
|
||||
|
||||
```py
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
This is an optional configuration for the invocation. It inherits from
|
||||
pydantic's model `Config` class, and it used primarily to customize the
|
||||
autogenerated OpenAPI schema.
|
||||
|
||||
The UI relies on the OpenAPI schema in two ways:
|
||||
|
||||
- An API client & Typescript types are generated from it. This happens at build
|
||||
time.
|
||||
- The node editor parses the schema into a template used by the UI to create the
|
||||
node editor UI. This parsing happens at runtime.
|
||||
|
||||
In this example, a `ui` key has been added to the `schema_extra` dict to provide
|
||||
some tags for the UI, to facilitate filtering nodes.
|
||||
|
||||
See the Schema Generation section below for more information.
|
||||
|
||||
### Invoke Function
|
||||
|
||||
```py
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
strength=0.0, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
```
|
||||
|
||||
@@ -135,9 +202,16 @@ scenarios. If you need functionality, please provide it as a service in the
|
||||
```py
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
type: Literal['image'] = 'image'
|
||||
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
```
|
||||
|
||||
Output classes look like an invocation class without the invoke method. Prefer
|
||||
@@ -168,35 +242,36 @@ Here's that `ImageOutput` class, without the needed schema customisation:
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
```
|
||||
|
||||
The generated OpenAPI schema, and all clients/types generated from it, will have
|
||||
the `type` and `image` properties marked as optional, even though we know they
|
||||
will always have a value by the time we can interact with them via the API.
|
||||
|
||||
Here's the same class, but with the schema customisation added:
|
||||
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
|
||||
`image`, `width` and `height` properties marked as optional, even though we know
|
||||
they will always have a value.
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
# Add schema customization
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
}
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
```
|
||||
|
||||
The resultant schema (and any API client or types generated from it) will now
|
||||
have see `type` as string literal `"image"` and `image` as an `ImageField`
|
||||
object.
|
||||
With the customization in place, the schema will now show these properties as
|
||||
required, obviating the need for extensive null checks in client code.
|
||||
|
||||
See this `pydantic` issue for discussion on this solution:
|
||||
<https://github.com/pydantic/pydantic/discussions/4577>
|
||||
|
||||
171
docs/features/LOGGING.md
Normal file
171
docs/features/LOGGING.md
Normal file
@@ -0,0 +1,171 @@
|
||||
---
|
||||
title: Controlling Logging
|
||||
---
|
||||
|
||||
# :material-image-off: Controlling Logging
|
||||
|
||||
## Controlling How InvokeAI Logs Status Messages
|
||||
|
||||
InvokeAI logs status messages using a configurable logging system. You
|
||||
can log to the terminal window, to a designated file on the local
|
||||
machine, to the syslog facility on a Linux or Mac, or to a properly
|
||||
configured web server. You can configure several logs at the same
|
||||
time, and control the level of message logged and the logging format
|
||||
(to a limited extent).
|
||||
|
||||
Three command-line options control logging:
|
||||
|
||||
### `--log_handlers <handler1> <handler2> ...`
|
||||
|
||||
This option activates one or more log handlers. Options are "console",
|
||||
"file", "syslog" and "http". To specify more than one, separate them
|
||||
by spaces:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
|
||||
```
|
||||
|
||||
The format of these options is described below.
|
||||
|
||||
### `--log_format {plain|color|legacy|syslog}`
|
||||
|
||||
This controls the format of log messages written to the console. Only
|
||||
the "console" log handler is currently affected by this setting.
|
||||
|
||||
* "plain" provides formatted messages like this:
|
||||
|
||||
```bash
|
||||
|
||||
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
|
||||
```
|
||||
|
||||
* "color" produces similar output, but the text will be color coded to
|
||||
indicate the severity of the message.
|
||||
|
||||
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
|
||||
|
||||
```bash
|
||||
### this is a critical error
|
||||
*** this is an error
|
||||
** this is a warning
|
||||
>> this is an informational messages
|
||||
| this is a debug message
|
||||
```
|
||||
|
||||
* "syslog" produces messages suitable for syslog entries:
|
||||
|
||||
```bash
|
||||
InvokeAI [2691178] <CRITICAL> this is a critical error
|
||||
InvokeAI [2691178] <ERROR> this is an error
|
||||
InvokeAI [2691178] <WARNING> this is a warning
|
||||
InvokeAI [2691178] <INFO> this is an informational messages
|
||||
InvokeAI [2691178] <DEBUG> this is a debug message
|
||||
```
|
||||
|
||||
(note that the date, time and hostname will be added by the syslog
|
||||
system)
|
||||
|
||||
### `--log_level {debug|info|warning|error|critical}`
|
||||
|
||||
Providing this command-line option will cause only messages at the
|
||||
specified level or above to be emitted.
|
||||
|
||||
## Console logging
|
||||
|
||||
When "console" is provided to `--log_handlers`, messages will be
|
||||
written to the command line window in which InvokeAI was launched. By
|
||||
default, the color formatter will be used unless overridden by
|
||||
`--log_format`.
|
||||
|
||||
## File logging
|
||||
|
||||
When "file" is provided to `--log_handlers`, entries will be written
|
||||
to the file indicated in the path argument. By default, the "plain"
|
||||
format will be used:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers file=/var/log/invokeai.log
|
||||
```
|
||||
|
||||
## Syslog logging
|
||||
|
||||
When "syslog" is requested, entries will be sent to the syslog
|
||||
system. There are a variety of ways to control where the log message
|
||||
is sent:
|
||||
|
||||
* Send to the local machine using the `/dev/log` socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=/dev/log
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message on a nonstandard
|
||||
port:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost:512
|
||||
```
|
||||
|
||||
* Send to a remote machine named "loghost" on the local LAN using
|
||||
facility LOG_USER and UDP packets:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
|
||||
```
|
||||
|
||||
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
|
||||
are defaults.
|
||||
|
||||
* Send to a remote machine named "loghost" using the facility LOCAL0
|
||||
and using a TCP socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
|
||||
```
|
||||
|
||||
If no arguments are specified (just a bare "syslog"), then the logging
|
||||
system will look for a UNIX socket named `/dev/log`, and if not found
|
||||
try to send a UDP message to `localhost`. The Macintosh OS used to
|
||||
support logging to a socket named `/var/run/syslog`, but this feature
|
||||
has since been disabled.
|
||||
|
||||
## Web logging
|
||||
|
||||
If you have access to a web server that is configured to log messages
|
||||
when a particular URL is requested, you can log using the "http"
|
||||
method:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
|
||||
```
|
||||
|
||||
The optional [,method=] part can be used to specify whether the URL
|
||||
accepts GET (default) or POST messages.
|
||||
|
||||
Currently password authentication and SSL are not supported.
|
||||
|
||||
## Using the configuration file
|
||||
|
||||
You can set and forget logging options by adding a "Logging" section
|
||||
to `invokeai.yaml`:
|
||||
|
||||
```
|
||||
InvokeAI:
|
||||
[... other settings...]
|
||||
Logging:
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=/dev/log
|
||||
log_level: info
|
||||
log_format: color
|
||||
```
|
||||
@@ -57,6 +57,9 @@ Personalize models by adding your own style or subjects.
|
||||
## * [The NSFW Checker](NSFW.md)
|
||||
Prevent InvokeAI from displaying unwanted racy images.
|
||||
|
||||
## * [Controlling Logging](LOGGING.md)
|
||||
Control how InvokeAI logs status messages.
|
||||
|
||||
## * [Miscellaneous](OTHER.md)
|
||||
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
||||
batch process a file of prompts, increase the "creativity" of image
|
||||
|
||||
@@ -67,7 +67,7 @@ title: Home
|
||||
implementation of Stable Diffusion, the open source text-to-image and
|
||||
image-to-image generator. It provides a streamlined process with various new
|
||||
features and options to aid the image generation process. It runs on Windows,
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB or RAM.
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
|
||||
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a
|
||||
|
||||
@@ -216,7 +216,7 @@ manager, please follow these steps:
|
||||
9. Run the command-line- or the web- interface:
|
||||
|
||||
From within INVOKEAI_ROOT, activate the environment
|
||||
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
|
||||
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
|
||||
the script `invokeai`. If the virtual environment you selected is NOT inside
|
||||
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
|
||||
`--root_dir \path\to\invokeai` to the commands below:
|
||||
|
||||
@@ -25,7 +25,7 @@ done
|
||||
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "A suitable Python interpreter could not be found"
|
||||
echo "Please install Python 3.9 or higher before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
read -p "Press any key to exit"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
@@ -7,42 +7,42 @@ call .venv\Scripts\activate.bat
|
||||
set INVOKEAI_ROOT=.
|
||||
|
||||
:start
|
||||
echo Do you want to generate images using the
|
||||
echo 1. command-line interface
|
||||
echo 2. browser-based UI
|
||||
echo 3. run textual inversion training
|
||||
echo 4. merge models (diffusers type only)
|
||||
echo 5. download and install models
|
||||
echo 6. change InvokeAI startup options
|
||||
echo 7. re-run the configure script to fix a broken install
|
||||
echo 8. open the developer console
|
||||
echo 9. update InvokeAI
|
||||
echo 10. command-line help
|
||||
echo Q - quit
|
||||
set /P restore="Please enter 1-10, Q: [2] "
|
||||
if not defined restore set restore=2
|
||||
IF /I "%restore%" == "1" (
|
||||
echo Desired action:
|
||||
echo 1. Generate images with the browser-based interface
|
||||
echo 2. Explore InvokeAI nodes using a command-line interface
|
||||
echo 3. Run textual inversion training
|
||||
echo 4. Merge models (diffusers type only)
|
||||
echo 5. Download and install models
|
||||
echo 6. Change InvokeAI startup options
|
||||
echo 7. Re-run the configure script to fix a broken install
|
||||
echo 8. Open the developer console
|
||||
echo 9. Update InvokeAI
|
||||
echo 10. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [2] "
|
||||
if not defined choice set choice=2
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai-web.exe %*
|
||||
) ELSE IF /I "%choice%" == "2" (
|
||||
echo Starting the InvokeAI command-line..
|
||||
python .venv\Scripts\invokeai.exe %*
|
||||
) ELSE IF /I "%restore%" == "2" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai.exe --web %*
|
||||
) ELSE IF /I "%restore%" == "3" (
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
echo Starting textual inversion training..
|
||||
python .venv\Scripts\invokeai-ti.exe --gui
|
||||
) ELSE IF /I "%restore%" == "4" (
|
||||
) ELSE IF /I "%choice%" == "4" (
|
||||
echo Starting model merging script..
|
||||
python .venv\Scripts\invokeai-merge.exe --gui
|
||||
) ELSE IF /I "%restore%" == "5" (
|
||||
) ELSE IF /I "%choice%" == "5" (
|
||||
echo Running invokeai-model-install...
|
||||
python .venv\Scripts\invokeai-model-install.exe
|
||||
) ELSE IF /I "%restore%" == "6" (
|
||||
) ELSE IF /I "%choice%" == "6" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
||||
) ELSE IF /I "%restore%" == "7" (
|
||||
) ELSE IF /I "%choice%" == "7" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --yes --default_only
|
||||
) ELSE IF /I "%restore%" == "8" (
|
||||
) ELSE IF /I "%choice%" == "8" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
@@ -54,15 +54,15 @@ IF /I "%restore%" == "1" (
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%restore%" == "9" (
|
||||
) ELSE IF /I "%choice%" == "9" (
|
||||
echo Running invokeai-update...
|
||||
python .venv\Scripts\invokeai-update.exe %*
|
||||
) ELSE IF /I "%restore%" == "10" (
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%restore%" == "q" (
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
echo Goodbye!
|
||||
goto ending
|
||||
) ELSE (
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT License
|
||||
|
||||
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
|
||||
# Copyright 2023, The InvokeAI Development Team
|
||||
|
||||
####
|
||||
# This launch script assumes that:
|
||||
# 1. it is located in the runtime directory,
|
||||
@@ -11,85 +16,168 @@
|
||||
|
||||
set -eu
|
||||
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
export INVOKEAI_ROOT="$scriptdir"
|
||||
PARAMS=$@
|
||||
|
||||
# set required env var for torch on mac MPS
|
||||
# Check to see if dialog is installed (it seems to be fairly standard, but good to check regardless) and if the user has passed the --no-tui argument to disable the dialog TUI
|
||||
tui=true
|
||||
if command -v dialog &>/dev/null; then
|
||||
# This must use $@ to properly loop through the arguments passed by the user
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" == "--no-tui" ]; then
|
||||
tui=false
|
||||
# Remove the --no-tui argument to avoid errors later on when passing arguments to InvokeAI
|
||||
PARAMS=$(echo "$PARAMS" | sed 's/--no-tui//')
|
||||
break
|
||||
fi
|
||||
done
|
||||
else
|
||||
tui=false
|
||||
fi
|
||||
|
||||
# Set required env var for torch on mac MPS
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true
|
||||
do
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line interface"
|
||||
echo "2. browser-based UI"
|
||||
echo "3. run textual inversion training"
|
||||
echo "4. merge models (diffusers type only)"
|
||||
echo "5. download and install models"
|
||||
echo "6. change InvokeAI startup options"
|
||||
echo "7. re-run the configure script to fix a broken install"
|
||||
echo "8. open the developer console"
|
||||
echo "9. update InvokeAI"
|
||||
echo "10. command-line help"
|
||||
echo "Q - Quit"
|
||||
echo ""
|
||||
read -p "Please enter 1-10, Q: [2] " yn
|
||||
choice=${yn:='2'}
|
||||
case $choice in
|
||||
1)
|
||||
echo "Starting the InvokeAI command-line..."
|
||||
invokeai $@
|
||||
;;
|
||||
2)
|
||||
echo "Starting the InvokeAI browser-based UI..."
|
||||
invokeai --web $@
|
||||
;;
|
||||
3)
|
||||
echo "Starting Textual Inversion:"
|
||||
invokeai-ti --gui $@
|
||||
;;
|
||||
4)
|
||||
echo "Merging Models:"
|
||||
invokeai-merge --gui $@
|
||||
;;
|
||||
5)
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
7)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
echo "Developer Console:"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
echo "Update:"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
invokeai --help
|
||||
;;
|
||||
[qQ])
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection"
|
||||
exit;;
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
1)
|
||||
clear
|
||||
printf "Generate images with a browser-based interface\n"
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Explore InvokeAI nodes using a command-line interface\n"
|
||||
invokeai $PARAMS
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Textual inversion training\n"
|
||||
invokeai-ti --gui $PARAMS
|
||||
;;
|
||||
4)
|
||||
clear
|
||||
printf "Merge models (diffusers type only)\n"
|
||||
invokeai-merge --gui $PARAMS
|
||||
;;
|
||||
5)
|
||||
clear
|
||||
printf "Download and install models\n"
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
clear
|
||||
printf "Change InvokeAI startup options\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
7)
|
||||
clear
|
||||
printf "Re-run the configure script to fix a broken install\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
clear
|
||||
printf "Update InvokeAI\n"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
;;
|
||||
"HELP 1")
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
printf "Exiting...\n"
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
done
|
||||
clear
|
||||
}
|
||||
|
||||
# Dialog-based TUI for launcing Invoke functions
|
||||
do_dialog() {
|
||||
options=(
|
||||
1 "Generate images with a browser-based interface"
|
||||
2 "Generate images using a command-line interface"
|
||||
3 "Textual inversion training"
|
||||
4 "Merge models (diffusers type only)"
|
||||
5 "Download and install models"
|
||||
6 "Change InvokeAI startup options"
|
||||
7 "Re-run the configure script to fix a broken install"
|
||||
8 "Open the developer console"
|
||||
9 "Update InvokeAI")
|
||||
|
||||
choice=$(dialog --clear \
|
||||
--backtitle "\Zb\Zu\Z3InvokeAI" \
|
||||
--colors \
|
||||
--title "What would you like to do?" \
|
||||
--ok-label "Run" \
|
||||
--cancel-label "Exit" \
|
||||
--help-button \
|
||||
--help-label "CLI Help" \
|
||||
--menu "Select an option:" \
|
||||
0 0 0 \
|
||||
"${options[@]}" \
|
||||
2>&1 >/dev/tty) || clear
|
||||
do_choice "$choice"
|
||||
clear
|
||||
}
|
||||
|
||||
# Command-line interface for launching Invoke functions
|
||||
do_line_input() {
|
||||
clear
|
||||
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Explore InvokeAI nodes using the command-line interface\n"
|
||||
printf "3: Run textual inversion training\n"
|
||||
printf "4: Merge models (diffusers type only)\n"
|
||||
printf "5: Download and install models\n"
|
||||
printf "6: Change InvokeAI startup options\n"
|
||||
printf "7: Re-run the configure script to fix a broken install\n"
|
||||
printf "8: Open the developer console\n"
|
||||
printf "9: Update InvokeAI\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
}
|
||||
|
||||
# Main IF statement for launching Invoke with either the TUI or CLI, and for checking if the user is in the developer console
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true; do
|
||||
if $tui; then
|
||||
# .dialogrc must be located in the same directory as the invoke.sh script
|
||||
export DIALOGRC="./.dialogrc"
|
||||
do_dialog
|
||||
else
|
||||
do_line_input
|
||||
fi
|
||||
done
|
||||
else # in developer console
|
||||
python --version
|
||||
echo "Press ^D to exit"
|
||||
printf "Press ^D to exit\n"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
|
||||
@@ -1,23 +1,34 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from typing import types
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.board_images import (
|
||||
BoardImagesService,
|
||||
BoardImagesServiceDependencies,
|
||||
)
|
||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.metadata import PngMetadataService
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@@ -37,54 +48,91 @@ def check_internet() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TO DO: Use the config to select the logger rather than use the default
|
||||
# invokeai logging module
|
||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||
logger.info(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||
)
|
||||
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
board_images = BoardImagesService(
|
||||
services=BoardImagesServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
services=ImageServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config,logger),
|
||||
model_manager=ModelManagerService(config,logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger),
|
||||
restoration=RestorationServices(config, logger),
|
||||
configuration=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
"""An image's metadata. Used only in HTTP responses."""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class SavedImage(BaseModel):
|
||||
image_name: str = Field(description="The name of the saved image")
|
||||
thumbnail_name: str = Field(description="The name of the saved thumbnail")
|
||||
created: int = Field(description="The created timestamp of the saved image")
|
||||
69
invokeai/app/api/routers/board_images.py
Normal file
69
invokeai/app/api/routers/board_images.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="create_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board_image(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||
|
||||
@board_images_router.delete(
|
||||
"/",
|
||||
operation_id="remove_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def remove_board_image(
|
||||
board_id: str = Body(description="The id of the board"),
|
||||
image_name: str = Body(description="The name of the image to remove"),
|
||||
):
|
||||
"""Deletes a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
|
||||
@board_images_router.get(
|
||||
"/{board_id}",
|
||||
operation_id="list_board_images",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_board_images(
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of boards per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
|
||||
board_id,
|
||||
)
|
||||
return results
|
||||
|
||||
108
invokeai/app/api/routers/boards.py
Normal file
108
invokeai/app/api/routers/boards.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Optional, Union
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
"/",
|
||||
operation_id="create_board",
|
||||
responses={
|
||||
201: {"description": "The board was created successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def create_board(
|
||||
board_name: str = Query(description="The name of the board to create"),
|
||||
) -> BoardDTO:
|
||||
"""Creates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
|
||||
|
||||
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||
async def get_board(
|
||||
board_id: str = Path(description="The id of board to get"),
|
||||
) -> BoardDTO:
|
||||
"""Gets a board"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
operation_id="update_board",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The board was updated successfully",
|
||||
},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def update_board(
|
||||
board_id: str = Path(description="The id of board to update"),
|
||||
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board")
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
) -> None:
|
||||
"""Deletes a board"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@boards_router.get(
|
||||
"/",
|
||||
operation_id="list_boards",
|
||||
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||
)
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
return ApiDependencies.invoker.services.boards.get_all()
|
||||
elif offset is not None and limit is not None:
|
||||
return ApiDependencies.invoker.services.boards.get_many(
|
||||
offset,
|
||||
limit,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
|
||||
)
|
||||
@@ -1,148 +1,241 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from typing import Optional
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import (
|
||||
ImageResponse,
|
||||
ImageResponseMetadata,
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image"""
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_type: ImageType = Path(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image and its thumbnail"""
|
||||
|
||||
ApiDependencies.invoker.services.images.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
|
||||
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a thumbnail"""
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/uploads/",
|
||||
"/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The image was uploaded successfully",
|
||||
"model": ImageResponse,
|
||||
},
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
415: {"description": "Image upload failed"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def upload_image(
|
||||
file: UploadFile, image_type: ImageType, request: Request, response: Response
|
||||
) -> ImageResponse:
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(contents))
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
except:
|
||||
# Error opening the image
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.EXTERNAL,
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
saved_image = ApiDependencies.invoker.services.images.save(
|
||||
image_type, filename, img
|
||||
)
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_dto.image_url
|
||||
|
||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||
return image_dto
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
image_type, saved_image.image_name
|
||||
)
|
||||
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
image_type, saved_image.image_name, True
|
||||
)
|
||||
@images_router.delete("/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image"""
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=image_type,
|
||||
image_name=saved_image.image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
metadata=ImageResponseMetadata(
|
||||
created=saved_image.created,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_url
|
||||
|
||||
return res
|
||||
@images_router.patch(
|
||||
"/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}",
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a full-resolution image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name)
|
||||
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=image_name,
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the image thumbnail",
|
||||
"content": {"image/webp": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_thumbnail(
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path, media_type="image/webp", content_disposition_type="inline"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
return ImageUrlsDTO(
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
operation_id="list_images_with_metadata",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(
|
||||
default=ImageType.RESULT, description="The type of images to get"
|
||||
async def list_images_with_metadata(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include"
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board id to filter by"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
||||
return result
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, Optional, Union, Dict
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@@ -19,6 +20,15 @@ class VaeRepo(BaseModel):
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
model_name: str = Field(description="The name of the model")
|
||||
model_type: str = Field(description="The type of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['folder'] = 'folder'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
@@ -29,12 +39,8 @@ class CkptModelInfo(ModelInfo):
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
class SafetensorsModelInfo(CkptModelInfo):
|
||||
format: Literal['safetensors'] = 'safetensors'
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
@@ -56,7 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@@ -64,9 +70,16 @@ class ModelsList(BaseModel):
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models() -> ModelsList:
|
||||
async def list_models(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
@@ -121,7 +134,7 @@ async def delete_model(model_name: str) -> None:
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error(f"Model not found")
|
||||
logger.error("Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import uvicorn
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
@@ -11,11 +11,20 @@ from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
#This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.routers import sessions, models, images, boards, board_images
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
@@ -33,30 +42,21 @@ app.add_middleware(
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=config, event_handler_id=event_handler_id, logger=logger
|
||||
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
|
||||
|
||||
@@ -74,10 +74,13 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@@ -117,6 +120,22 @@ def custom_openapi():
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
type="string",
|
||||
enum=list(v.value for v in model_config_format_enum),
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
@@ -124,7 +143,7 @@ def custom_openapi():
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
@@ -143,19 +162,21 @@ def overridden_redoc():
|
||||
redoc_favicon_url="/static/favicon.ico",
|
||||
)
|
||||
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
|
||||
app.mount("/",
|
||||
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||
html=True
|
||||
), name="ui"
|
||||
)
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
|
||||
config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -285,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand):
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
get_subactions = action._get_subactions
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._indent()
|
||||
if isinstance(action, argparse._SubParsersAction):
|
||||
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
|
||||
yield subaction
|
||||
else:
|
||||
for subaction in get_subactions():
|
||||
yield subaction
|
||||
self._dedent()
|
||||
|
||||
@@ -11,9 +11,10 @@ from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager, Globals
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
@@ -131,13 +132,13 @@ class Completer(object):
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(model_manager)
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
@@ -153,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
histfile = Path(Globals.root, ".invoke_history")
|
||||
histfile = Path(services.configuration.root_dir / ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
|
||||
@@ -4,35 +4,44 @@ import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
from typing import (
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
from typing import Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||
create_system_graphs)
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
||||
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||
SortedHelpFormatter, add_graph_parsers, add_parsers)
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
|
||||
GraphInvocation, LibraryGraph,
|
||||
are_connection_types_compatible)
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
|
||||
@@ -43,7 +52,6 @@ class CliCommand(BaseModel):
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
@@ -64,7 +72,7 @@ def add_invocation_args(command_parser):
|
||||
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
@@ -187,49 +195,71 @@ def invoke_all(context: CliContext):
|
||||
|
||||
raise SessionError()
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
set_autocompleter(model_manager)
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
metadata=metadata,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
set_autocompleter(services)
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
@@ -241,10 +271,18 @@ def invoke_cli():
|
||||
# print(services.session_manager.list())
|
||||
|
||||
context = CliContext(invoker, session, parser)
|
||||
set_autocompleter(services)
|
||||
|
||||
while True:
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
try:
|
||||
cmd_input = input("invoke> ")
|
||||
if command_line_args_exist:
|
||||
cmd_input = invocation_commands.pop(0)
|
||||
done = len(invocation_commands) == 0
|
||||
else:
|
||||
cmd_input = input("invoke> ")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl-c exits
|
||||
break
|
||||
@@ -368,6 +406,9 @@ def invoke_cli():
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except ValidationError:
|
||||
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.invocation_services import InvocationServices
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
@@ -75,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
#fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||
#fmt: on
|
||||
|
||||
|
||||
@@ -92,6 +96,7 @@ class UIConfig(TypedDict, total=False):
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
"control",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from pydantic import Field, validator
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
@@ -22,9 +22,17 @@ class IntCollectionOutput(BaseInvocationOutput):
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of floats"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = Field(default=[], description="The float collection")
|
||||
|
||||
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
type: Literal["range"] = "range"
|
||||
|
||||
@@ -33,12 +41,34 @@ class RangeInvocation(BaseInvocation):
|
||||
stop: int = Field(default=10, description="The stop of the range")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
@validator("stop")
|
||||
def stop_gt_start(cls, v, values):
|
||||
if "start" in values and v <= values["start"]:
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
|
||||
type: Literal["range_of_size"] = "range_of_size"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
size: int = Field(default=1, description="The number of values")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.start + self.size, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from contextlib import ExitStack
|
||||
import re
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .model import ClipField
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
Fragment, Conjunction,
|
||||
)
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
@@ -41,7 +42,7 @@ class CompelInvocation(BaseInvocation):
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
model: str = Field(default="", description="Model to use")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@@ -57,87 +58,93 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
# TODO: load without model
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
pipeline = model["model"]
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer,\
|
||||
text_encoder_info as text_encoder,\
|
||||
ExitStack() as stack:
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||
|
||||
# TODO: support legacy blend?
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
#print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
|
||||
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
|
||||
|
||||
if getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.set(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
||||
for c in blend.prompts
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in blend.prompts
|
||||
]
|
||||
)
|
||||
elif type(prompt) is Conjunction:
|
||||
conjunction: Conjunction = prompt
|
||||
return sum(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in conjunction.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
@@ -172,6 +179,22 @@ def get_tokens_for_prompt_object(
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts)>1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(
|
||||
p,
|
||||
tokenizer,
|
||||
display_label_prefix=this_display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
|
||||
454
invokeai/app/invocations/controlnet_image_processors.py
Normal file
454
invokeai/app/invocations/controlnet_image_processors.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# InvokeAI nodes for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import float
|
||||
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
HEDdetector,
|
||||
LineartDetector,
|
||||
LineartAnimeDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
OpenposeDetector,
|
||||
PidiNetDetector,
|
||||
ContentShuffleDetector,
|
||||
ZoeDetector,
|
||||
MediapipeFaceDetector,
|
||||
)
|
||||
|
||||
from .image import ImageOutput, PILInvocationConfig
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
||||
##############################################
|
||||
"lllyasviel/sd-controlnet-canny",
|
||||
"lllyasviel/sd-controlnet-depth",
|
||||
"lllyasviel/sd-controlnet-hed",
|
||||
"lllyasviel/sd-controlnet-seg",
|
||||
"lllyasviel/sd-controlnet-openpose",
|
||||
"lllyasviel/sd-controlnet-scribble",
|
||||
"lllyasviel/sd-controlnet-normal",
|
||||
"lllyasviel/sd-controlnet-mlsd",
|
||||
|
||||
#############################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||
#############################################
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
"lllyasviel/control_v11p_sd15_openpose",
|
||||
"lllyasviel/control_v11p_sd15_seg",
|
||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
"lllyasviel/control_v11p_sd15_normalbae",
|
||||
"lllyasviel/control_v11p_sd15_scribble",
|
||||
"lllyasviel/control_v11p_sd15_mlsd",
|
||||
"lllyasviel/control_v11p_sd15_softedge",
|
||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
"lllyasviel/control_v11p_sd15_lineart",
|
||||
"lllyasviel/control_v11p_sd15_inpaint",
|
||||
# "lllyasviel/control_v11u_sd15_tile",
|
||||
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
||||
"lllyasviel/control_v11e_sd15_shuffle",
|
||||
"lllyasviel/control_v11e_sd15_ip2p",
|
||||
"lllyasviel/control_v11f1e_sd15_tile",
|
||||
|
||||
#################################################
|
||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||
##################################################
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
||||
"thibaud/controlnet-sd21-hed-diffusers",
|
||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
||||
"thibaud/controlnet-sd21-color-diffusers",
|
||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||
|
||||
##############################################
|
||||
# ControlNetMediaPipeface, ControlNet v1.1
|
||||
##############################################
|
||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
||||
]
|
||||
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
# fmt: off
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
),
|
||||
)
|
||||
|
||||
# TODO: move image processors to separate file (image_analysis.py
|
||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_processor"] = "image_processor"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to process")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def run_processor(self, image):
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# FIXME: what happened to image metadata?
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
# )
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.services.images.create(
|
||||
image=processed_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width = image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height = image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies HED edge detection to image"""
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = hed_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = lineart_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
coarse=self.coarse)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art anime processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Openpose processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
# Inputs
|
||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Midas depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
# Inputs
|
||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies NormalBae processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies MLSD processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d)
|
||||
return processed_image
|
||||
|
||||
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies PIDI processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble)
|
||||
return processed_image
|
||||
|
||||
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies content shuffle processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
# Inputs
|
||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
@@ -7,9 +7,9 @@ import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class CvInvocationConfig(BaseModel):
|
||||
@@ -26,24 +26,23 @@ class CvInvocationConfig(BaseModel):
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
"""Simple inpaint using opencv."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||
cv_mask = numpy.array(ImageOps.invert(mask))
|
||||
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
|
||||
|
||||
# Inpaint
|
||||
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
|
||||
@@ -52,18 +51,17 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
# TODO: consider making a utility function
|
||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=image_inpainted,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_inpainted,
|
||||
)
|
||||
@@ -3,120 +3,77 @@
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
ResourceOrigin)
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
|
||||
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
import re
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
from .compel import ConditioningField
|
||||
from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
from .latent import get_scheduler
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
class OldModelContext(ContextDecorator):
|
||||
model: StableDiffusionGeneratorPipeline
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __enter__(self):
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
class OldModelInfo:
|
||||
name: str
|
||||
hash: str
|
||||
context: OldModelContext
|
||||
|
||||
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
|
||||
self.name = name
|
||||
self.hash = hash
|
||||
self.context = OldModelContext(
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(BaseInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(
|
||||
image_type, image_name, generate_output.image, metadata
|
||||
)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=generate_output.image,
|
||||
)
|
||||
|
||||
|
||||
class ImageToImageInvocation(TextToImageInvocation):
|
||||
"""Generates an image using img2img."""
|
||||
|
||||
type: Literal["img2img"] = "img2img"
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet model")
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
@@ -128,92 +85,41 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
)
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||
seam_blur: int = Field(
|
||||
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||
)
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
||||
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
||||
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
||||
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
||||
seam_steps: int = Field(
|
||||
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||
)
|
||||
tile_size: int = Field(
|
||||
default=32, ge=1, description="The tile infill method size (px)"
|
||||
)
|
||||
infill_method: INFILL_METHODS = Field(
|
||||
default=DEFAULT_INFILL_METHOD,
|
||||
description="The method used to infill empty regions (px)",
|
||||
)
|
||||
inpaint_width: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the inpaint region (px)",
|
||||
)
|
||||
inpaint_height: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the inpaint region (px)",
|
||||
)
|
||||
inpaint_fill: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The solid infill method color",
|
||||
)
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
@@ -221,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
description="The amount by which to replace masked areas with latent noise",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -234,60 +148,101 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_conditioning(self, context):
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
return (uc, c, extra_conditioning_info)
|
||||
|
||||
@contextmanager
|
||||
def load_model_old_way(self, context, scheduler):
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
|
||||
#unet = unet_info.context.model
|
||||
#vae = vae_info.context.model
|
||||
|
||||
with ExitStack() as stack:
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with vae_info as vae,\
|
||||
unet_info as unet,\
|
||||
ModelPatcher.apply_lora_unet(unet, loras):
|
||||
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
conditioning = self.get_conditioning(context)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
with self.load_model_old_way(context, scheduler) as model:
|
||||
outputs = Inpaint(model).generate(
|
||||
conditioning=conditioning,
|
||||
scheduler=scheduler,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageField, ImageType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
@@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput):
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
image_type: ImageType, image_name: str, image: Image.Image
|
||||
) -> ImageOutput:
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
image_field = ImageField(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=image_field,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
@@ -80,16 +67,17 @@ class LoadImageInvocation(BaseInvocation):
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to load"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image_type, self.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image_type,
|
||||
image_name=self.image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,32 +87,32 @@ class ShowImageInvocation(BaseInvocation):
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to show")
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
if image:
|
||||
image.show()
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image.image_type,
|
||||
image_name=self.image.image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
type: Literal["img_crop"] = "img_crop"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to crop")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
@@ -132,58 +120,51 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_crop = Image.new(
|
||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||
)
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image_dto = context.services.images.create(
|
||||
image=image_crop,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
type: Literal["img_paste"] = "img_paste"
|
||||
|
||||
# Inputs
|
||||
base_image: ImageField = Field(default=None, description="The base image")
|
||||
image: ImageField = Field(default=None, description="The image to paste")
|
||||
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get(
|
||||
self.base_image.image_type, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@@ -199,20 +180,19 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image_dto = context.services.images.create(
|
||||
image=new_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -223,48 +203,150 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=image_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return MaskOutput(
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_mask, metadata)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
|
||||
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
image2 = context.services.images.get_pil_image(self.image2.image_name)
|
||||
|
||||
multiply_image = ImageChops.multiply(image1, image2)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=multiply_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=channel_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=converted_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["blur"] = "blur"
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to blur")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius)
|
||||
@@ -273,74 +355,171 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=blur_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, blur_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=blur_image
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
PIL_RESAMPLING_MODES = Literal[
|
||||
"nearest",
|
||||
"box",
|
||||
"bilinear",
|
||||
"hamming",
|
||||
"bicubic",
|
||||
"lanczos",
|
||||
]
|
||||
|
||||
|
||||
PIL_RESAMPLING_MAP = {
|
||||
"nearest": Image.Resampling.NEAREST,
|
||||
"box": Image.Resampling.BOX,
|
||||
"bilinear": Image.Resampling.BILINEAR,
|
||||
"hamming": Image.Resampling.HAMMING,
|
||||
"bicubic": Image.Resampling.BICUBIC,
|
||||
"lanczos": Image.Resampling.LANCZOS,
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
resize_image = image.resize(
|
||||
(self.width, self.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
width = int(image.width * self.scale_factor)
|
||||
height = int(image.height * self.scale_factor)
|
||||
|
||||
resize_image = image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lerp"] = "lerp"
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
image_arr = image_arr * (self.max - self.min) + self.max
|
||||
|
||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=lerp_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=lerp_image
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["ilerp"] = "ilerp"
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = (
|
||||
@@ -352,16 +531,17 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=ilerp_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
from typing import Literal, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput, build_image_output
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageField, ImageType
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
@@ -125,36 +125,35 @@ class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: Optional[ColorField] = Field(
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
color: ColorField = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image)
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -163,7 +162,9 @@ class InfillTileInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
@@ -173,29 +174,26 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -204,30 +202,29 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,34 +1,36 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional, Union
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput, build_image_output
|
||||
from .compel import ConditioningField
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||
image_resized_to_grid_as_tensor)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
@@ -81,10 +83,17 @@ SAMPLER_NAME_VALUES = Literal[
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
@@ -119,7 +128,6 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
@@ -139,12 +147,17 @@ class NoiseInvocation(BaseInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
@validator("seed", pre=True)
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
|
||||
return v % SEED_MAX
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(choose_torch_device())
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
|
||||
|
||||
@@ -160,20 +173,36 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -189,49 +218,138 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
self.cfg_scale,
|
||||
extra_conditioning_info,
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
scheduler,
|
||||
|
||||
# for ddim scheduler
|
||||
eta=0.0, #ddim_eta
|
||||
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
#configure_model_padding(
|
||||
# unet,
|
||||
# self.seamless,
|
||||
# self.seamless_axes,
|
||||
#)
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self):
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self):
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||
)
|
||||
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
|
||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
if control_input is None:
|
||||
# print("control input is None")
|
||||
control_list = None
|
||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||
# print("control input is empty list")
|
||||
control_list = None
|
||||
elif isinstance(control_input, ControlField):
|
||||
# print("control input is ControlField")
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||
# print("control input is list[ControlField]")
|
||||
control_list = control_input
|
||||
else:
|
||||
# print("input control is unrecognized:", type(self.control))
|
||||
control_list = None
|
||||
if (control_list is None):
|
||||
control_data = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
# handle control models
|
||||
if ("," in control_info.control_model):
|
||||
control_model_split = control_info.control_model.split(",")
|
||||
control_name = control_model_split[0]
|
||||
control_subfolder = control_model_split[1]
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(control_name,
|
||||
subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = model.prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
)
|
||||
control_item = ControlNetData(model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
@@ -243,27 +361,46 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
@@ -271,7 +408,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@@ -279,7 +416,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -295,31 +434,58 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
with unet_info as unet,\
|
||||
ExitStack() as stack:
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=unet.device,
|
||||
)
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
@@ -331,16 +497,14 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@@ -348,30 +512,44 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
with vae_info as vae:
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=image
|
||||
)
|
||||
with torch.inference_mode():
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
image = vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
@@ -404,7 +582,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
@@ -434,7 +613,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
@@ -445,38 +625,50 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
# image = context.services.images.get(
|
||||
# self.image.image_type, self.image.image_name
|
||||
# )
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
with vae_info as vae:
|
||||
if self.tiled:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
|
||||
latents = 0.18215 * latents
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, latents)
|
||||
# context.services.latents.set(name, latents)
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
@@ -5,7 +5,12 @@ from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
@@ -22,19 +27,30 @@ class MathInvocationConfig(BaseModel):
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""A float output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["float_output"] = "float_output"
|
||||
param: float = Field(default=None, description="The output float")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
@@ -42,11 +58,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
@@ -54,11 +71,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
@@ -66,11 +84,12 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
@@ -78,8 +97,13 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
#fmt: on
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
|
||||
217
invokeai/app/invocations/model.py
Normal file
217
invokeai/app/invocations/model.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from typing import Literal, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
import copy
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(description="Info to load submodel")
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
|
||||
class UNetField(BaseModel):
|
||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
|
||||
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["model_loader_output"] = "model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
#fmt: on
|
||||
|
||||
|
||||
class PipelineModelField(BaseModel):
|
||||
"""Pipeline model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a pipeline model, outputting its submodels."""
|
||||
|
||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||
|
||||
model: PipelineModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Pipeline
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
#fmt: on
|
||||
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["lora_loader"] = "lora_loader"
|
||||
|
||||
lora_name: str = Field(description="Lora model name")
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.lora_name,
|
||||
model_type=SDModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {self.lora_name}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras):
|
||||
raise Exception(f"Lora \"{self.lora_name}\" already applied to unet")
|
||||
|
||||
if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras):
|
||||
raise Exception(f"Lora \"{self.lora_name}\" already applied to clip")
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
model_name=self.lora_name,
|
||||
model_type=SDModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
model_name=self.lora_name,
|
||||
model_type=SDModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
237
invokeai/app/invocations/param_easing.py
Normal file
237
invokeai/app/invocations/param_easing.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
|
||||
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
logger = InvokeAILogger.getLogger(name="StepParamEasing")
|
||||
logger.debug("start_step: " + str(start_step))
|
||||
logger.debug("end_step: " + str(end_step))
|
||||
logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
logger.debug("num_presteps: " + str(num_presteps))
|
||||
logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
logger.debug("prelist: " + str(prelist))
|
||||
logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = list()
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
from .math import IntOutput, FloatOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
@@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
|
||||
@@ -2,8 +2,8 @@ from typing import Literal
|
||||
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .baseinvocation import BaseInvocationOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a prompt"""
|
||||
@@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
|
||||
'prompt',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class PromptCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a collection of prompts"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
||||
|
||||
prompt_collection: list[str] = Field(description="The output prompt collection")
|
||||
count: int = Field(description="The size of the prompt collection")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "prompt_collection", "count"]}
|
||||
|
||||
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = Field(
|
||||
default=False, description="Whether to use the combinatorial generator"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
if self.combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
||||
else:
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
|
||||
@@ -2,21 +2,23 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["restore_face"] = "restore_face"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
@@ -26,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=None,
|
||||
@@ -39,18 +39,17 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
||||
@@ -4,22 +4,22 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["upscale"] = "upscale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
@@ -43,18 +41,17 @@ class UpscaleInvocation(BaseInvocation):
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_name and not model_manager.valid_model(model_name):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
else:
|
||||
model = model_manager.get_model(model_name)
|
||||
|
||||
return model
|
||||
@@ -2,31 +2,74 @@ from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
def is_image_type(obj):
|
||||
try:
|
||||
ImageType(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
- INTERNAL: The resource was created by the application.
|
||||
- EXTERNAL: The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
INTERNAL = "internal"
|
||||
"""The resource was created by the application."""
|
||||
EXTERNAL = "external"
|
||||
"""The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOriginException(ValueError):
|
||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid resource origin."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image.
|
||||
|
||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
- MASK: The image is a mask image.
|
||||
- CONTROL: The image is a ControlNet control image.
|
||||
- USER: The image is a user-provide image.
|
||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
"""
|
||||
|
||||
GENERAL = "general"
|
||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||
MASK = "mask"
|
||||
"""MASK: The image is a mask image."""
|
||||
CONTROL = "control"
|
||||
"""CONTROL: The image is a ControlNet control image."""
|
||||
USER = "user"
|
||||
"""USER: The image is a user-provide image."""
|
||||
OTHER = "other"
|
||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||
|
||||
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageCategory.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: ImageType = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
@@ -37,3 +80,11 @@ class ColorField(BaseModel):
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
93
invokeai/app/models/metadata.py
Normal file
93
invokeai/app/models/metadata.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Optional, Union, List
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""
|
||||
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||
|
||||
Also includes any metadata from the image's PNG tEXt chunks.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||
of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
"""
|
||||
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
||||
won't add any fields that are not already defined, but other a different metadata service
|
||||
implementation might.
|
||||
"""
|
||||
|
||||
type: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The type of the ancestor node of the image output node.",
|
||||
)
|
||||
"""The type of the ancestor node of the image output node."""
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
"""The positive conditioning"""
|
||||
negative_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
"""The negative conditioning"""
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/latents in pixels."
|
||||
)
|
||||
"""Width of the image/latents in pixels"""
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/latents in pixels."
|
||||
)
|
||||
"""Height of the image/latents in pixels"""
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
# cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Union[float, list[float]] = Field(
|
||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
steps: Optional[StrictInt] = Field(
|
||||
default=None, description="The number of steps used for inference."
|
||||
)
|
||||
"""The number of steps used for inference"""
|
||||
scheduler: Optional[StrictStr] = Field(
|
||||
default=None, description="The scheduler used for inference."
|
||||
)
|
||||
"""The scheduler used for inference"""
|
||||
model: Optional[StrictStr] = Field(
|
||||
default=None, description="The model used for inference."
|
||||
)
|
||||
"""The model used for inference"""
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/latents-to-latents.",
|
||||
)
|
||||
"""The strength used for image-to-image/latents-to-latents."""
|
||||
latents: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial latents."
|
||||
)
|
||||
"""The ID of the initial latents"""
|
||||
vae: Optional[StrictStr] = Field(
|
||||
default=None, description="The VAE used for decoding."
|
||||
)
|
||||
"""The VAE used for decoding"""
|
||||
unet: Optional[StrictStr] = Field(
|
||||
default=None, description="The UNet used dor inference."
|
||||
)
|
||||
"""The UNet used dor inference"""
|
||||
clip: Optional[StrictStr] = Field(
|
||||
default=None, description="The CLIP Encoder used for conditioning."
|
||||
)
|
||||
"""The CLIP Encoder used for conditioning"""
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||
)
|
||||
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
||||
254
invokeai/app/services/board_image_record_storage.py
Normal file
254
invokeai/app/services/board_image_record_storage.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Union, cast
|
||||
from invokeai.app.services.board_record_storage import BoardRecord
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
|
||||
class BoardImageRecordStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets images for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_count_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> int:
|
||||
"""Gets the number of images for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `board_images` junction table."""
|
||||
|
||||
# Create the `board_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS board_images (
|
||||
board_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between boards and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON board_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
|
||||
""",
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE board_id = ? AND image_name = ?;
|
||||
""",
|
||||
(board_id, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
# TODO: this isn't paginated yet?
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
142
invokeai/app/services/board_images.py
Normal file
142
invokeai/app/services/board_images.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import List, Union
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardRecord,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardImagesServiceABC(ABC):
|
||||
"""High-level service for board-image relationship management."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets images for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardImagesServiceDependencies:
|
||||
"""Service dependencies for the BoardImagesService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardImagesService(BoardImagesServiceABC):
|
||||
_services: BoardImagesServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardImagesServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
image_records = self._services.board_image_records.get_images_for_board(
|
||||
board_id
|
||||
)
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
board_id,
|
||||
),
|
||||
image_records.items,
|
||||
)
|
||||
)
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=image_records.offset,
|
||||
limit=image_records.limit,
|
||||
total=image_records.total,
|
||||
)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
||||
return board_id
|
||||
|
||||
|
||||
def board_record_to_dto(
|
||||
board_record: BoardRecord, cover_image_name: str | None, image_count: int
|
||||
) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.dict(exclude={'cover_image_name'}),
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
)
|
||||
329
invokeai/app/services/board_record_storage.py
Normal file
329
invokeai/app/services/board_record_storage.py
Normal file
@@ -0,0 +1,329 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import (
|
||||
BoardRecord,
|
||||
deserialize_board_record,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
board_name: Optional[str] = Field(description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's new cover image."
|
||||
)
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the board record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, board_id: str) -> None:
|
||||
"""Deletes a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
"""Saves a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
"""Gets a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
"""Updates a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets many board records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
"""Gets all board records."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `boards` table and `board_images` junction table."""
|
||||
|
||||
# Create the `boards` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS boards (
|
||||
board_id TEXT NOT NULL PRIMARY KEY,
|
||||
board_name TEXT NOT NULL,
|
||||
cover_image_name TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||
AFTER UPDATE
|
||||
ON boards FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE boards SET updated_at = current_timestamp
|
||||
WHERE board_id = old.board_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = str(uuid.uuid4())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Get all the boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
|
||||
# Get the total number of boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](
|
||||
items=boards, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Get all the boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
|
||||
return boards
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
185
invokeai/app/services/boards.py
Normal file
185
invokeai/app/services/boards.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from logging import Logger
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_images import board_record_to_dto
|
||||
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardChanges,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardServiceABC(ABC):
|
||||
"""High-level service for board management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
"""Creates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardDTO:
|
||||
"""Gets a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
"""Updates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> None:
|
||||
"""Deletes a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
"""Gets many boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardDTO]:
|
||||
"""Gets all boards."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardServiceDependencies:
|
||||
"""Service dependencies for the BoardService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardService(BoardServiceABC):
|
||||
_services: BoardServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.save(board_name)
|
||||
return board_record_to_dto(board_record, None, 0)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self._services.board_records.get(board_id)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.update(board_id, changes)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self._services.board_records.delete(board_id)
|
||||
|
||||
def get_many(
|
||||
self, offset: int = 0, limit: int = 10
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self._services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](
|
||||
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||
)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self._services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
||||
539
invokeai/app/services/config.py
Normal file
539
invokeai/app/services/config.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
'''Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
has a top-level key of "InvokeAI" and subheadings for each of the
|
||||
categories returned by `invokeai --help`. The file looks like this:
|
||||
|
||||
[file: invokeai.yaml]
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
autoconvert_dir: null
|
||||
Models:
|
||||
model: stable-diffusion-1.5
|
||||
embeddings: true
|
||||
Memory/Performance:
|
||||
xformers_enabled: false
|
||||
sequential_guidance: false
|
||||
precision: float16
|
||||
max_loaded_models: 4
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
nsfw_checker: true
|
||||
restore: true
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
Web Server:
|
||||
host: 127.0.0.1
|
||||
port: 8081
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||
OmegaConf dictionary object initialization time:
|
||||
|
||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||
conf = InvokeAIAppConfig()
|
||||
conf.parse_args(conf=omegaconf)
|
||||
|
||||
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||
at initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf.parse_args(argv=['--xformers_enabled'])
|
||||
|
||||
It is also possible to set a value at initialization time. However, if
|
||||
you call parse_args() it may be overwritten.
|
||||
|
||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
# False
|
||||
|
||||
|
||||
To avoid this, use `get_config()` to retrieve the application-wide
|
||||
configuration object. This will retain any properties set at object
|
||||
creation time:
|
||||
|
||||
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
# True
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
form: "INVOKEAI_<setting>", as in:
|
||||
|
||||
export INVOKEAI_port=8080
|
||||
|
||||
Order of precedence (from highest):
|
||||
1) initialization options
|
||||
2) command line options
|
||||
3) environment variable options
|
||||
4) config file options
|
||||
5) pydantic defaults
|
||||
|
||||
Typical usage at the top level file:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
conf.parse_args()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
Typical usage in a backend module:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
|
||||
Computed properties:
|
||||
|
||||
The InvokeAIAppConfig object has a series of properties that
|
||||
resolve paths relative to the runtime root directory. They each return
|
||||
a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
|
||||
In most cases, you will want to create a single InvokeAIAppConfig
|
||||
object for the entire application. The InvokeAIAppConfig.get_config() function
|
||||
does this:
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args() # read values from the command line/config file
|
||||
print(config.root)
|
||||
|
||||
# Subclassing
|
||||
|
||||
If you wish to create a similar class, please subclass the
|
||||
`InvokeAISettings` class and define a Literal field named "type",
|
||||
which is set to the desired top-level name. For example, to create a
|
||||
"InvokeBatch" configuration, define like this:
|
||||
|
||||
class InvokeBatch(InvokeAISettings):
|
||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
||||
|
||||
This will now read and write from the "InvokeBatch" section of the
|
||||
config file, look for environment variables named INVOKEBATCH_*, and
|
||||
accept the command-line arguments `--node_count` and `--cpu_count`. The
|
||||
two configs are kept in separate sections of the config file:
|
||||
|
||||
# invokeai.yaml
|
||||
|
||||
InvokeBatch:
|
||||
Resources:
|
||||
node_count: 1
|
||||
cpu_count: 8
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
DB_FILE = Path('invokeai.db')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
|
||||
def to_yaml(self)->str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)['type'])[0]
|
||||
field_dict = dict({type:dict()})
|
||||
for name,field in self.__fields__.items():
|
||||
if name in cls._excluded():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self,name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if 'type' in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||
|
||||
initconf = cls.initconf.get(settings_stanza) \
|
||||
if cls.initconf and settings_stanza in cls.initconf \
|
||||
else OmegaConf.create()
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key,value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category","Uncategorized")
|
||||
env_name = env_prefix + '_' + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'Uncategorized'
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
return ['type','initconf']
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs='*',
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
|
||||
or
|
||||
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
|
||||
)
|
||||
):
|
||||
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
|
||||
#fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
|
||||
#fmt: on
|
||||
|
||||
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||
'''
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
'''
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
super().parse_args(argv)
|
||||
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||
|
||||
@classmethod
|
||||
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
'''
|
||||
if cls.singleton_config is None \
|
||||
or type(cls.singleton_config)!=cls \
|
||||
or (kwargs and cls.singleton_init != kwargs):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
if self.root:
|
||||
return Path(self.root).expanduser()
|
||||
else:
|
||||
return self.find_root()
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
Alias for above.
|
||||
'''
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def init_file_path(self)->Path:
|
||||
'''
|
||||
Path to invokeai.yaml
|
||||
'''
|
||||
return self._resolve(INIT_FILE)
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self)->Path:
|
||||
'''
|
||||
Path to the invokeai.db file.
|
||||
'''
|
||||
return self._resolve(self.db_dir) / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
Path to models configuration file.
|
||||
'''
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self)->Path:
|
||||
'''
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
'''
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self)->Path:
|
||||
'''
|
||||
Path to the models directory
|
||||
'''
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision=='float32'
|
||||
|
||||
@property
|
||||
def disable_xformers(self)->bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self)->bool:
|
||||
"""Return true if patchmatch true"""
|
||||
return self.patchmatch
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
'''
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
'''
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
@@ -1,9 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@@ -101,3 +102,53 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_started (
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_session_event(
|
||||
event_name="model_load_started",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_session_event(
|
||||
event_name="model_load_completed",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -60,6 +60,33 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
from typing import Optional, Union, List, get_args
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
if not t1_args:
|
||||
# t1 is a single type
|
||||
return t1 in t2_args
|
||||
else:
|
||||
# t1 is a Union, check that all of its types are in t2_args
|
||||
return all(arg in t2_args for arg in t1_args)
|
||||
|
||||
def is_list_or_contains_list(t):
|
||||
t_args = get_args(t)
|
||||
|
||||
# If the type is a List
|
||||
if get_origin(t) is list:
|
||||
return True
|
||||
|
||||
# If the type is a Union
|
||||
elif t_args:
|
||||
# Check if any of the types in the Union is a List
|
||||
for arg in t_args:
|
||||
if get_origin(arg) is list:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if not from_type:
|
||||
@@ -85,7 +112,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
if not issubclass(from_type, to_type):
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@@ -135,6 +163,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
@@ -162,6 +191,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(
|
||||
@@ -361,7 +391,7 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
@@ -372,41 +402,41 @@ class Graph(BaseModel):
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
@@ -692,7 +722,11 @@ class Graph(BaseModel):
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all((get_origin(f) == list for f in output_fields)):
|
||||
# if not all((get_origin(f) == list for f in output_fields)):
|
||||
# return False
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
@@ -711,6 +745,13 @@ class Graph(BaseModel):
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.items()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||
) -> nx.DiGraph:
|
||||
@@ -816,11 +857,9 @@ class GraphExecutionState(BaseModel):
|
||||
if next_node is None:
|
||||
prepared_id = self._prepare()
|
||||
|
||||
# TODO: prepare multiple nodes at once?
|
||||
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
|
||||
# prepared_id = self._prepare()
|
||||
|
||||
if prepared_id is not None:
|
||||
# Prepare as many nodes as we can
|
||||
while prepared_id is not None:
|
||||
prepared_id = self._prepare()
|
||||
next_node = self._get_next_node()
|
||||
|
||||
# Get values from edges
|
||||
@@ -967,14 +1006,30 @@ class GraphExecutionState(BaseModel):
|
||||
# Get flattened source graph
|
||||
g = self.graph.nx_graph_flat()
|
||||
|
||||
# Find next unprepared node where all source nodes are executed
|
||||
# Find next node that:
|
||||
# - was not already prepared
|
||||
# - is not an iterate node whose inputs have not been executed
|
||||
# - does not have an unexecuted iterate ancestor
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node_id = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
# exclude nodes that have already been prepared
|
||||
if n not in self.source_prepared_mapping
|
||||
and all((e[0] in self.executed for e in g.in_edges(n)))
|
||||
# exclude iterate nodes whose inputs have not been executed
|
||||
and not (
|
||||
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
|
||||
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
|
||||
)
|
||||
# exclude nodes who have unexecuted iterate ancestors
|
||||
and not any(
|
||||
(
|
||||
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
|
||||
and a not in self.executed # ...that is not executed
|
||||
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
|
||||
)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -1071,9 +1126,22 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
def _get_next_node(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the deepest node that is ready to be executed"""
|
||||
g = self.execution_graph.nx_graph()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if next_node is None:
|
||||
return None
|
||||
|
||||
|
||||
186
invokeai/app/services/image_file_storage.py
Normal file
186
invokeai/app/services/image_file_storage.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||
# 500 internal server error. I don't like having this method on the service.
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: Path
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[Path, PILImageType]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str | Path):
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
|
||||
|
||||
# Validate required output folders at launch
|
||||
self.__validate_storage_folders()
|
||||
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
except FileNotFoundError as e:
|
||||
raise ImageFileNotFoundException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
self.__validate_storage_folders()
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", metadata.json())
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
except Exception as e:
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
if image_path.exists():
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||
|
||||
if thumbnail_path.exists():
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
raise ImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
|
||||
path = self.__output_folder / image_name
|
||||
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
path = self.__thumbnails_folder / thumbnail_name
|
||||
|
||||
return path
|
||||
|
||||
def validate_path(self, path: str | Path) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
|
||||
def __validate_storage_folders(self) -> None:
|
||||
"""Checks if the required output folders exist and create them if they don't"""
|
||||
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
|
||||
for folder in folders:
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __get_cache(self, image_name: Path) -> PILImageType | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
475
invokeai/app/services/image_record_storage.py
Normal file
475
invokeai/app/services/image_record_storage.py
Normal file
@@ -0,0 +1,475 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
offset: int = Field(description="Offset from which to retrieve items")
|
||||
limit: int = Field(description="Limit of items to get")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `images` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
board_id TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_name: str) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = """--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
images_params.append(limit)
|
||||
images_params.append(offset)
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
height: int,
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else metadata.json(exclude_none=True)
|
||||
)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(
|
||||
self, board_id: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
@@ -1,274 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, List
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from send2trash import send2trash
|
||||
from invokeai.app.api.models.images import (
|
||||
ImageResponse,
|
||||
ImageResponseMetadata,
|
||||
SavedImage,
|
||||
)
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import (
|
||||
InvokeAIMetadata,
|
||||
MetadataServiceBase,
|
||||
build_invokeai_metadata_pnginfo,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a paginated list of images."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_uri(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the external URI to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> SavedImage:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
"""Creates a unique contextual image filename."""
|
||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
__metadata_service: MetadataServiceBase
|
||||
|
||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
||||
self.__output_folder = output_folder
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
self.__metadata_service = metadata_service
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
|
||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
||||
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=self.get_uri(image_type, filename),
|
||||
thumbnail_url=self.get_uri(image_type, filename, True),
|
||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
||||
metadata=ImageResponseMetadata(
|
||||
created=int(os.path.getctime(path)),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = PILImage.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", basename
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
return abspath
|
||||
|
||||
def get_uri(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
thumbnail_basename = get_thumbnail_name(basename)
|
||||
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
|
||||
else:
|
||||
uri = f"api/v1/images/{image_type.value}/{basename}"
|
||||
|
||||
return uri
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> SavedImage:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
||||
# TODO: Reading the image and then saving it strips the metadata...
|
||||
if metadata:
|
||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path) # this saved image has an empty info
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
|
||||
return SavedImage(
|
||||
image_name=image_name,
|
||||
thumbnail_name=thumbnail_name,
|
||||
created=int(os.path.getctime(image_path)),
|
||||
)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_type, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: Image):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
354
invokeai/app/services/images.py
Normal file
354
invokeai/app/services/images.py
Normal file
@@ -0,0 +1,354 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
ImageFileStorageBase,
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
"""Updates an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
image_records: ImageRecordStorageBase
|
||||
image_files: ImageFileStorageBase
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
names: NameServiceBase
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.image_records = image_record_storage
|
||||
self.image_files = image_file_storage
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.names = names
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self._services.names.create_image_name()
|
||||
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self._services.image_records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.image_files.save(
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem saving image record and file")
|
||||
raise e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.image_records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.image_records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.image_files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.image_files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(
|
||||
r.image_name
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Union[ImageMetadata, None]:
|
||||
"""Get the metadata for a node."""
|
||||
metadata = None
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
session = self._services.graph_execution_manager.get(session_id)
|
||||
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
||||
|
||||
return metadata
|
||||
@@ -1,54 +1,68 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing import types
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.restoration_services import RestorationServices
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.config import InvokeAISettings
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .latent_storage import LatentsStorageBase
|
||||
from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
board_images: "BoardImagesServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAISettings"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
images: "ImageServiceABC"
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
queue: "InvocationQueueABC"
|
||||
restoration: "RestorationServices"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
logger: types.ModuleType,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
self,
|
||||
board_images: "BoardImagesServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAISettings",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
images: "ImageServiceABC",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
queue: "InvocationQueueABC",
|
||||
restoration: "RestorationServices",
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.graph_library = graph_library
|
||||
self.images = images
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.processor = processor
|
||||
self.queue = queue
|
||||
self.restoration = restoration
|
||||
|
||||
@@ -22,7 +22,8 @@ class Invoker:
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> str | None:
|
||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
@@ -16,7 +15,7 @@ class LatentsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -47,8 +46,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.set(name, data)
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
@@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: str
|
||||
__output_folder: str | Path
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
def __init__(self, output_folder: str | Path):
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
os.remove(latent_path)
|
||||
latent_path.unlink()
|
||||
|
||||
|
||||
def get_path(self, name: str) -> str:
|
||||
return os.path.join(self.__output_folder, name)
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
||||
|
||||
@@ -1,105 +1,142 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, TypedDict
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Union
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.image import ImageType, is_image_type
|
||||
|
||||
|
||||
class MetadataImageField(TypedDict):
|
||||
"""Pydantic-less ImageField, used for metadata parsing."""
|
||||
|
||||
image_type: ImageType
|
||||
image_name: str
|
||||
|
||||
|
||||
class MetadataLatentsField(TypedDict):
|
||||
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||||
|
||||
latents_name: str
|
||||
|
||||
|
||||
class MetadataColorField(TypedDict):
|
||||
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||
r: int
|
||||
g: int
|
||||
b: int
|
||||
a: int
|
||||
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
|
||||
]
|
||||
|
||||
|
||||
class InvokeAIMetadata(TypedDict, total=False):
|
||||
"""InvokeAI-specific metadata format."""
|
||||
|
||||
session_id: Optional[str]
|
||||
node: Optional[NodeMetadata]
|
||||
|
||||
|
||||
def build_invokeai_metadata_pnginfo(
|
||||
metadata: InvokeAIMetadata | None,
|
||||
) -> PngImagePlugin.PngInfo:
|
||||
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
|
||||
return pnginfo
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||||
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||||
pass
|
||||
"""Handles building metadata for nodes, images, and outputs."""
|
||||
|
||||
@abstractmethod
|
||||
def build_metadata(
|
||||
self, session_id: str, node: BaseModel
|
||||
) -> InvokeAIMetadata | None:
|
||||
"""Builds an InvokeAIMetadata object"""
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""Builds an ImageMetadata object for a node."""
|
||||
pass
|
||||
|
||||
|
||||
class PngMetadataService(MetadataServiceBase):
|
||||
"""Handles loading and building metadata for images."""
|
||||
class CoreMetadataService(MetadataServiceBase):
|
||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||
"""The ancestor types that contain the core metadata"""
|
||||
|
||||
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||||
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Loads a specific info entry from a PIL Image."""
|
||||
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
||||
"""The core metadata parameters in the ancestor types"""
|
||||
|
||||
try:
|
||||
info = image.info.get("invokeai")
|
||||
_NOISE_FIELDS = ["seed", "width", "height"]
|
||||
"""The core metadata parameters in the noise node"""
|
||||
|
||||
if type(info) is not str:
|
||||
return None
|
||||
|
||||
loaded_metadata = json.loads(info)
|
||||
|
||||
if type(loaded_metadata) is not dict:
|
||||
return None
|
||||
|
||||
if len(loaded_metadata.items()) == 0:
|
||||
return None
|
||||
|
||||
return loaded_metadata
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Retrieves an image's metadata as a dict"""
|
||||
loaded_metadata = self._load_metadata(image)
|
||||
|
||||
return loaded_metadata
|
||||
|
||||
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||||
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
metadata = self._build_metadata_from_graph(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
|
||||
"""
|
||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||
|
||||
Parameters:
|
||||
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
||||
have the same data as the execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
||||
"""
|
||||
|
||||
# Retrieve the node from the graph
|
||||
node = G.nodes[node_id]
|
||||
|
||||
# If the node type is one of the core metadata node types, return its id
|
||||
if node.get("type") in self._ANCESTOR_TYPES:
|
||||
return node.get("id")
|
||||
|
||||
# Else, look for the ancestor in the predecessor nodes
|
||||
for predecessor in G.predecessors(node_id):
|
||||
result = self._find_nearest_ancestor(G, predecessor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# If there are no valid ancestors, return None
|
||||
return None
|
||||
|
||||
def _get_additional_metadata(
|
||||
self, graph: Graph, node_id: str
|
||||
) -> Union[dict[str, Any], None]:
|
||||
"""
|
||||
Returns additional metadata for a given node.
|
||||
|
||||
Parameters:
|
||||
graph (Graph): The execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: A dictionary of additional metadata.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
# Iterate over all edges in the graph
|
||||
for edge in graph.edges:
|
||||
dest_node_id = edge.destination.node_id
|
||||
dest_field = edge.destination.field
|
||||
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
||||
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# Prompt
|
||||
if dest_field == "positive_conditioning":
|
||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||
# Negative prompt
|
||||
if dest_field == "negative_conditioning":
|
||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||
# Seed, width and height
|
||||
if dest_field == "noise":
|
||||
for field in self._NOISE_FIELDS:
|
||||
metadata[field] = source_node_dict.get(field)
|
||||
return metadata
|
||||
|
||||
def _build_metadata_from_graph(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""
|
||||
Builds an ImageMetadata object for a node.
|
||||
|
||||
Parameters:
|
||||
session (GraphExecutionState): The session.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
ImageMetadata: The metadata for the node.
|
||||
"""
|
||||
|
||||
# We need to do all the traversal on the execution graph
|
||||
graph = session.execution_graph
|
||||
|
||||
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||
|
||||
# If no ancestor was found, return an empty ImageMetadata object
|
||||
if ancestor_id is None:
|
||||
return ImageMetadata()
|
||||
|
||||
ancestor_node = graph.get_node(ancestor_id)
|
||||
|
||||
# Grab all the core metadata from the ancestor node
|
||||
ancestor_metadata = {
|
||||
param: val
|
||||
for param, val in ancestor_node.dict().items()
|
||||
if param in self._ANCESTOR_PARAMS
|
||||
}
|
||||
|
||||
# Get this image's prompts and noise parameters
|
||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
if addl_metadata is not None:
|
||||
ancestor_metadata.update(addl_metadata)
|
||||
|
||||
return ImageMetadata(**ancestor_metadata)
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
|
||||
)
|
||||
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
import transformers # type: ignore
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
import diffusers
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config.embedding_path)
|
||||
)
|
||||
else:
|
||||
embedding_path = config.embedding_path
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# migrate legacy models
|
||||
ModelManager.migrate_models()
|
||||
|
||||
# creating the model manager
|
||||
try:
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = 'float16' if config.precision=='float16' \
|
||||
else 'float32' if config.precision=='float32' \
|
||||
else choose_precision(device)
|
||||
|
||||
model_manager = ModelManager(
|
||||
OmegaConf.load(config.conf),
|
||||
precision=precision,
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = Path(embedding_path),
|
||||
logger = logger,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e, logger)
|
||||
except (IOError, KeyError) as e:
|
||||
logger.error(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := config.autoconvert:
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
logger.info('Model manager initialized')
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
|
||||
)
|
||||
if response.startswith(("n", "N")):
|
||||
return
|
||||
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
if yes_to_all is not None:
|
||||
for arg in yes_to_all.split():
|
||||
sys.argv.append(arg)
|
||||
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
|
||||
invokeai_configure()
|
||||
# TODO: Figure out how to restart
|
||||
# print('** InvokeAI will now restart')
|
||||
# sys.argv = previous_args
|
||||
# main() # would rather do a os.exec(), but doesn't exist?
|
||||
# sys.exit(0)
|
||||
363
invokeai/app/services/model_manager_service.py
Normal file
363
invokeai/app/services/model_manager_service.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
ModelManager,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
)
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from .config import InvokeAIAppConfig
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
pass
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if config.model_conf_path and config.model_conf_path.exists():
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
if not config_file.exists():
|
||||
raise IOError(f"The file {config_file} could not be found.")
|
||||
|
||||
logger.debug(f'config file={config_file}')
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == 'float32' else torch.float16
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `max_cache_size` config setting
|
||||
max_cache_size = config.max_cache_size \
|
||||
if hasattr(config,'max_cache_size') \
|
||||
else config.max_loaded_models * 2.5
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
|
||||
self.mgr = ModelManager(
|
||||
config=config_file,
|
||||
device_type=device,
|
||||
precision=dtype,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info('Model manager service initialized')
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
# if we are called from within a node, then we get to emit
|
||||
# load start and complete events
|
||||
if node and context:
|
||||
self._emit_load_event(
|
||||
node=node,
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
|
||||
if node and context:
|
||||
self._emit_load_event(
|
||||
node=node,
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name, base_model, model_type)
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
# ) -> dict:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
)->None:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
return self.mgr.commit(conf_file)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
node,
|
||||
context,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node.dict(),
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node.dict(),
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
62
invokeai/app/services/models/board_record.py
Normal file
62
invokeai/app/services/models/board_record.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class BoardRecord(BaseModel):
|
||||
"""Deserialized board record."""
|
||||
|
||||
board_id: str = Field(description="The unique ID of the board.")
|
||||
"""The unique ID of the board."""
|
||||
board_name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
created_at: Union[datetime, str] = Field(
|
||||
description="The created timestamp of the board."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(
|
||||
description="The updated timestamp of the board."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the board."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the cover image of the board."
|
||||
)
|
||||
"""The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's cover image."
|
||||
)
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
"""Deserializes a board record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
board_id = board_dict.get("board_id", "unknown")
|
||||
board_name = board_dict.get("board_name", "unknown")
|
||||
cover_image_name = board_dict.get("cover_image_name", "unknown")
|
||||
created_at = board_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
|
||||
|
||||
return BoardRecord(
|
||||
board_id=board_id,
|
||||
board_name=board_name,
|
||||
cover_image_name=cover_image_name,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
)
|
||||
151
invokeai/app/services/models/image_record.py
Normal file
151
invokeai/app/services/models/image_record.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
"""Deserialized image record."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
"""The category of the image."""
|
||||
width: int = Field(description="The width of the image in px.")
|
||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||
height: int = Field(description="The height of the image in px.")
|
||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(
|
||||
description="The updated timestamp of the image."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the image."
|
||||
)
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
session_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The session ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The session ID that generated this image, if it is a generated image."""
|
||||
node_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None,
|
||||
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||
)
|
||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
"""A set of changes to apply to an image record.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
description="The image's new category."
|
||||
)
|
||||
"""The image's new category."""
|
||||
session_id: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
"""The URL of the image's thumbnail."""
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Union[str, None] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
"""Deserializes an image record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
width = image_dict.get("width", 0)
|
||||
height = image_dict.get("height", 0)
|
||||
session_id = image_dict.get("session_id", None)
|
||||
node_id = image_dict.get("node_id", None)
|
||||
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
if raw_metadata is not None:
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
session_id=session_id,
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
30
invokeai/app/services/resource_name.py
Normal file
30
invokeai/app/services/resource_name.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
import uuid
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
|
||||
IMAGE = "image"
|
||||
LATENT = "latent"
|
||||
|
||||
|
||||
class NameServiceBase(ABC):
|
||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
@abstractmethod
|
||||
def create_image_name(self) -> str:
|
||||
"""Creates a name for an image."""
|
||||
pass
|
||||
|
||||
|
||||
class SimpleNameService(NameServiceBase):
|
||||
"""Creates image names from UUIDs."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
||||
@@ -16,13 +16,14 @@ class RestorationServices:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
restoration = Restoration()
|
||||
if args.restore:
|
||||
# TODO: redo for new model structure
|
||||
if False and args.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
if False and args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
logger.info("Upscaling disabled")
|
||||
|
||||
@@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
|
||||
25
invokeai/app/services/urls.py
Normal file
25
invokeai/app/services/urls.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class UrlServiceBase(ABC):
|
||||
"""Responsible for building URLs for resources."""
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/{image_basename}"
|
||||
15
invokeai/app/util/metaenum.py
Normal file
15
invokeai/app/util/metaenum.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import EnumMeta
|
||||
|
||||
|
||||
class MetaEnum(EnumMeta):
|
||||
"""Metaclass to support additional features in Enums.
|
||||
|
||||
- `in` operator support: `'value' in MyEnum -> bool`
|
||||
"""
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
||||
cls(item)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
@@ -6,6 +6,14 @@ def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
def get_iso_timestamp() -> str:
|
||||
return datetime.datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generate import Generate
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import ModelManager, SDModelComponent
|
||||
from .model_management import (
|
||||
ModelManager, ModelCache, BaseModelType,
|
||||
ModelType, SubModelType, ModelInfo
|
||||
)
|
||||
from .safety_checker import SafetyChecker
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,6 @@ from .base import (
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint,
|
||||
Generator,
|
||||
|
||||
@@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
@@ -75,17 +74,21 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
**kwargs,
|
||||
):
|
||||
self.model_info=model_info
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(self,
|
||||
prompt: str='',
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
def generate(
|
||||
self,
|
||||
conditioning: tuple,
|
||||
scheduler,
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
'''
|
||||
Return an iterator across the indicated number of generations.
|
||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||
@@ -111,51 +114,46 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
model_info = self.model_info
|
||||
model_name = model_info['model_name']
|
||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model_hash = model_info['hash']
|
||||
scheduler: Scheduler = self.get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
model_name = model_info.name
|
||||
model_hash = model_info.hash
|
||||
with model_info.context as model:
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(
|
||||
conditioning=conditioning,
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
yield output
|
||||
|
||||
@classmethod
|
||||
@@ -168,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
@@ -191,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
return Generator
|
||||
|
||||
# ------------------------------------
|
||||
class Txt2Img(InvokeAIGenerator):
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .txt2img import Txt2Img
|
||||
return Txt2Img
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
@@ -251,36 +228,17 @@ class Inpaint(Img2Img):
|
||||
from .inpaint import Inpaint
|
||||
return Inpaint
|
||||
|
||||
# ------------------------------------
|
||||
class Embiggen(Txt2Img):
|
||||
def generate(
|
||||
self,
|
||||
embiggen: list=None,
|
||||
embiggen_tiles: list = None,
|
||||
strength: float=0.75,
|
||||
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
strength=strength,
|
||||
**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .embiggen import Embiggen
|
||||
return Embiggen
|
||||
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
precision: str
|
||||
model: DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionPipeline, precision: str):
|
||||
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.latent_channels = model.unet.config.in_channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
@@ -291,7 +249,7 @@ class Generator:
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, prompt, **kwargs):
|
||||
def get_make_image(self, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
@@ -307,7 +265,6 @@ class Generator:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
width,
|
||||
height,
|
||||
sampler,
|
||||
@@ -332,7 +289,6 @@ class Generator:
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
width=width,
|
||||
|
||||
@@ -1,559 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.embiggen descends from .generator
|
||||
and generates with .generator.img2img
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=None,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
for _ in trange(iterations, desc="Generating"):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert (
|
||||
not sampler.uses_inpainting_model()
|
||||
), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model, self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert("RGB")
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth * embiggen[0])
|
||||
initsuperheight = round(initsuperheight * embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
|
||||
return -1 * (-a // b)
|
||||
|
||||
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert (
|
||||
emb_tiles_x > 1 or emb_tiles_y > 1
|
||||
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = (
|
||||
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
|
||||
)
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
# Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
|
||||
# Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(
|
||||
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
|
||||
(0, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerAA.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
else:
|
||||
logger.info(
|
||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
)
|
||||
|
||||
emb_tile_store = []
|
||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
||||
# produce duplicated structures for each tile and make the tiling effect more obvious
|
||||
# instead track and iterate a local seed we pass to Img2Img
|
||||
seed = self.seed
|
||||
seedintlimit = (
|
||||
np.iinfo(np.uint32).max - 1
|
||||
) # only retreive this one from numpy
|
||||
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
# Don't iterate on first tile
|
||||
if tile != 0:
|
||||
if seed < seedintlimit:
|
||||
seed += 1
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
# Determine if this is a re-run and replace
|
||||
if embiggen_tiles and not tile in embiggen_tiles:
|
||||
continue
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
# Determine bounds to cut up the init image
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
# Cropped image of above dimension (does not modify the original)
|
||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||
# DEBUG:
|
||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
logger.debug(
|
||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||
newinitimage = torch.from_numpy(newinitimage)
|
||||
newinitimage = 2.0 * newinitimage - 1.0
|
||||
newinitimage = newinitimage.to(self.model.device)
|
||||
clear_cuda_cache = (
|
||||
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||
)
|
||||
|
||||
tile_results = gen_img2img.generate(
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
conditioning=conditioning,
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=None, # called only after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
width=width,
|
||||
height=height,
|
||||
init_image=newinitimage, # notice that init_image is different from init_img
|
||||
mask_image=None,
|
||||
strength=strength,
|
||||
clear_cuda_cache=clear_cuda_cache,
|
||||
)
|
||||
|
||||
emb_tile_store.append(tile_results[0][0])
|
||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||
del newinitimage
|
||||
|
||||
# Sanity check we have them all
|
||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
|
||||
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
|
||||
):
|
||||
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
|
||||
if embiggen_tiles:
|
||||
outputsuperimage.alpha_composite(
|
||||
initsuperimage.convert("RGBA"), (0, 0)
|
||||
)
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
if embiggen_tiles:
|
||||
if tile in embiggen_tiles:
|
||||
intileimage = emb_tile_store.pop(0)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
intileimage = emb_tile_store[tile]
|
||||
intileimage = intileimage.convert("RGBA")
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
||||
left = 0
|
||||
top = 0
|
||||
else:
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
# Handle gradients for various conditions
|
||||
# Handle emb_rerun case
|
||||
if embiggen_tiles:
|
||||
# top of image
|
||||
if emb_row_i == 0:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) not in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerB)
|
||||
# Otherwise do nothing on this tile
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRBC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerLR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABT)
|
||||
# bottom of image
|
||||
elif emb_row_i == emb_tiles_y - 1:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
# No tiles to look ahead to
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
# vertical middle of image
|
||||
else:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTB)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABL)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerAA)
|
||||
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
||||
else:
|
||||
if emb_row_i == 0 and emb_column_i >= 1:
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
elif emb_row_i >= 1 and emb_column_i == 0:
|
||||
if (
|
||||
emb_column_i + 1 == emb_tiles_x
|
||||
): # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerT)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
if (
|
||||
emb_column_i + 1 == emb_tiles_x
|
||||
): # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
logger.error(
|
||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
)
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
return outputsuperimage
|
||||
|
||||
# end of function declaration
|
||||
return make_image
|
||||
@@ -22,7 +22,6 @@ class Img2Img(Generator):
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
|
||||
@@ -161,9 +161,7 @@ class Inpaint(Img2Img):
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@@ -177,8 +175,6 @@ class Inpaint(Img2Img):
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@@ -196,15 +192,13 @@ class Inpaint(Img2Img):
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
|
||||
result = make_image(seam_noise, seed)
|
||||
result = make_image(seam_noise, seed=None)
|
||||
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@@ -306,7 +300,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
@@ -345,9 +338,7 @@ class Inpaint(Img2Img):
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@@ -360,8 +351,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ..stable_diffusion import (
|
||||
ConditioningData,
|
||||
PostprocessingSettings,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from .base import Generator
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
warmup=0.2,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=warmup,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||
noise=x_T,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
||||
@@ -1,209 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||
|
||||
from ..stable_diffusion import PostprocessingSettings
|
||||
from .base import Generator
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt: str,
|
||||
sampler,
|
||||
steps: int,
|
||||
cfg_scale: float,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width: int,
|
||||
height: int,
|
||||
strength: float,
|
||||
step_callback: Optional[Callable] = None,
|
||||
threshold=0.0,
|
||||
warmup=0.2,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=0.2,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int):
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise=x_T,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# Get our initial generation width and height directly from the latent output so
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
logger.info(
|
||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
first_pass_latent_output,
|
||||
size=(
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
),
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = kwargs["clear_cuda_cache"] or None
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
second_pass_noise = self.get_noise_like(
|
||||
resized_latents, override_perlin=True
|
||||
)
|
||||
|
||||
# Clear symmetry for the second pass
|
||||
from dataclasses import replace
|
||||
|
||||
new_postprocessing_settings = replace(
|
||||
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
|
||||
)
|
||||
new_postprocessing_settings = replace(
|
||||
new_postprocessing_settings, v_symmetry_time_pct=None
|
||||
)
|
||||
new_conditioning_data = replace(
|
||||
conditioning_data, postprocessing_settings=new_postprocessing_settings
|
||||
)
|
||||
|
||||
verbosity = get_verbosity()
|
||||
set_verbosity_error()
|
||||
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
||||
resized_latents,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=new_conditioning_data,
|
||||
strength=strength,
|
||||
noise=second_pass_noise,
|
||||
callback=step_callback,
|
||||
)
|
||||
set_verbosity(verbosity)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
# FIXME: do we really need something entirely different for the inpainting model?
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpaing model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
|
||||
device = like.device
|
||||
if device.type == "mps":
|
||||
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
||||
if self.perlin > 0.0 and override_perlin == False:
|
||||
shape = like.shape
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||
shape[3], shape[2]
|
||||
)
|
||||
return x
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self, width, height, scale=True):
|
||||
# print(f"Get noise: {width}x{height}")
|
||||
if scale:
|
||||
# Scale the input width and height for the initial generation
|
||||
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
|
||||
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
|
||||
|
||||
aspect = width / height
|
||||
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
model_area = (
|
||||
dimension * dimension
|
||||
) # hardcoded for now since all models are trained on square images
|
||||
|
||||
if aspect > 1.0:
|
||||
init_height = max(min_dimension, math.sqrt(model_area / aspect))
|
||||
init_width = init_height * aspect
|
||||
else:
|
||||
init_width = max(min_dimension, math.sqrt(model_area * aspect))
|
||||
init_height = init_width / aspect
|
||||
|
||||
scaled_width, scaled_height = trim_to_multiple_of(
|
||||
math.floor(init_width), math.floor(init_height)
|
||||
)
|
||||
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
channels = self.latent_channels
|
||||
if channels == 9:
|
||||
channels = 4 # we don't really want noise for all the mask channels
|
||||
shape = (
|
||||
1,
|
||||
channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor,
|
||||
)
|
||||
if self.use_mps_noise or device.type == "mps":
|
||||
tensor = torch.empty(size=shape, device="cpu")
|
||||
tensor = self.get_noise_like(like=tensor).to(device)
|
||||
else:
|
||||
tensor = torch.empty(size=shape, device=device)
|
||||
tensor = self.get_noise_like(like=tensor)
|
||||
return tensor
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.globals defines a small number of global variables that would
|
||||
otherwise have to be passed through long and complex call chains.
|
||||
|
||||
It defines a Namespace object named "Globals" that contains
|
||||
the attributes:
|
||||
|
||||
- root - the root directory under which "models" and "outputs" can be found
|
||||
- initfile - path to the initialization file
|
||||
- try_patchmatch - option to globally disable loading of 'patchmatch' module
|
||||
- always_use_cpu - force use of CPU even if GPU is available
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
Globals = Namespace()
|
||||
|
||||
# Where to look for the initialization file and other key components
|
||||
Globals.initfile = "invokeai.init"
|
||||
Globals.models_file = "models.yaml"
|
||||
Globals.models_dir = "models"
|
||||
Globals.config_dir = "configs"
|
||||
Globals.autoscan_dir = "weights"
|
||||
Globals.converted_ckpts_dir = "converted_ckpts"
|
||||
|
||||
# Set the default root directory. This can be overwritten by explicitly
|
||||
# passing the `--root <directory>` argument on the command line.
|
||||
# logic is:
|
||||
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
|
||||
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
|
||||
# 3) use ~/invokeai
|
||||
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
|
||||
):
|
||||
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
|
||||
else:
|
||||
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
|
||||
|
||||
# Try loading patchmatch
|
||||
Globals.try_patchmatch = True
|
||||
|
||||
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
|
||||
Globals.always_use_cpu = False
|
||||
|
||||
# Whether the internet is reachable for dynamic downloads
|
||||
# The CLI will test connectivity at startup time.
|
||||
Globals.internet_available = True
|
||||
|
||||
# Whether to disable xformers
|
||||
Globals.disable_xformers = False
|
||||
|
||||
# Low-memory tradeoff for guidance calculations.
|
||||
Globals.sequential_guidance = False
|
||||
|
||||
# whether we are forcing full precision
|
||||
Globals.full_precision = False
|
||||
|
||||
# whether we should convert ckpt files into diffusers models on the fly
|
||||
Globals.ckpt_convert = True
|
||||
|
||||
# logging tokenization everywhere
|
||||
Globals.log_tokenization = False
|
||||
|
||||
|
||||
def global_config_file() -> Path:
|
||||
return Path(Globals.root, Globals.config_dir, Globals.models_file)
|
||||
|
||||
|
||||
def global_config_dir() -> Path:
|
||||
return Path(Globals.root, Globals.config_dir)
|
||||
|
||||
|
||||
def global_models_dir() -> Path:
|
||||
return Path(Globals.root, Globals.models_dir)
|
||||
|
||||
|
||||
def global_autoscan_dir() -> Path:
|
||||
return Path(Globals.root, Globals.autoscan_dir)
|
||||
|
||||
|
||||
def global_converted_ckpts_dir() -> Path:
|
||||
return Path(global_models_dir(), Globals.converted_ckpts_dir)
|
||||
|
||||
|
||||
def global_set_root(root_dir: Union[str, Path]):
|
||||
Globals.root = root_dir
|
||||
|
||||
|
||||
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
|
||||
"""
|
||||
Returns Path to the model cache directory. If a subdirectory
|
||||
is provided, it will be appended to the end of the path, allowing
|
||||
for Hugging Face-style conventions. Currently, Hugging Face has
|
||||
moved all models into the "hub" subfolder, so for any pretrained
|
||||
HF model, use:
|
||||
global_cache_dir('hub')
|
||||
|
||||
The legacy location for transformers used to be global_cache_dir('transformers')
|
||||
and global_cache_dir('diffusers') for diffusers.
|
||||
"""
|
||||
home: str = os.getenv("HF_HOME")
|
||||
|
||||
if home is None:
|
||||
home = os.getenv("XDG_CACHE_HOME")
|
||||
|
||||
if home is not None:
|
||||
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
|
||||
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
|
||||
home += os.sep + "huggingface"
|
||||
|
||||
if home is not None:
|
||||
return Path(home, subdir)
|
||||
else:
|
||||
return Path(Globals.root, "models", subdir)
|
||||
@@ -6,7 +6,8 @@ be suppressed or deferred
|
||||
"""
|
||||
import numpy as np
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
@@ -23,7 +24,7 @@ class PatchMatch:
|
||||
def _load_patch_match(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
if Globals.try_patchmatch:
|
||||
if config.try_patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
|
||||
@@ -33,11 +33,11 @@ from PIL import Image, ImageOps
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
CLIPSEG_SIZE = 352
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||
@@ -88,10 +88,10 @@ class Txt2Mask(object):
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -12,17 +12,17 @@ print("Loading Python libraries...\n",file=sys.stderr)
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import HfFolder
|
||||
@@ -35,57 +35,53 @@ from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ...frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from ...frontend.install.widgets import (
|
||||
from invokeai.app.services.config import (
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
CyclingForm,
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
)
|
||||
from ..args import PRECISION_CHOICES, Args
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
|
||||
from .model_install_backend import (
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import (
|
||||
default_dataset,
|
||||
download_from_hf,
|
||||
hf_download_with_resume,
|
||||
recommended_datasets,
|
||||
UserSelections,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
Default_config_file = Path(global_config_dir()) / "models.yaml"
|
||||
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
Datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
# minimum size for the UI
|
||||
MIN_COLS = 135
|
||||
MIN_LINES = 45
|
||||
PRECISION_CHOICES = ['auto','float16','float32','autocast']
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
# Place frequently-used startup commands here, one or more per line.
|
||||
# Examples:
|
||||
# --outdir=D:\data\images
|
||||
# --no-nsfw_checker
|
||||
# --web --host=0.0.0.0
|
||||
# --steps=20
|
||||
# -Ak_euler_a -C10.0
|
||||
"""
|
||||
|
||||
logger=None
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@@ -96,14 +92,13 @@ If you installed manually from source or with 'pip install': activate the virtua
|
||||
then run one of the following commands to start InvokeAI.
|
||||
|
||||
Web UI:
|
||||
invokeai --web # (connect to http://localhost:9090)
|
||||
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
|
||||
invokeai-web
|
||||
|
||||
Command-line interface:
|
||||
Command-line client:
|
||||
invokeai
|
||||
|
||||
If you installed using an installation script, run:
|
||||
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
Add the '--help' argument to see all of the command-line switches available for use.
|
||||
"""
|
||||
@@ -215,16 +210,11 @@ def download_realesrgan():
|
||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
wdn_model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
|
||||
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
|
||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
||||
|
||||
|
||||
def download_gfpgan():
|
||||
@@ -243,8 +233,8 @@ def download_gfpgan():
|
||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
||||
],
|
||||
):
|
||||
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
|
||||
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
|
||||
model_url, model_dest = model[0], config.root_path / model[1]
|
||||
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -253,8 +243,8 @@ def download_codeformer():
|
||||
model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
)
|
||||
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
|
||||
download_with_progress_bar(model_url, model_dest, "CodeFormer")
|
||||
model_dest = config.root_path / "models/codeformer/codeformer.pth"
|
||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -295,7 +285,7 @@ def download_vaes():
|
||||
# first the diffusers version
|
||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||
args = dict(
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
cache_dir=config.cache_dir,
|
||||
)
|
||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
||||
raise Exception(f"download of {repo_id} failed")
|
||||
@@ -306,7 +296,7 @@ def download_vaes():
|
||||
if not hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_name=model_name,
|
||||
model_dir=str(Globals.root / Model_dir / Weights_dir),
|
||||
model_dir=str(config.root_path / Model_dir / Weights_dir),
|
||||
):
|
||||
raise Exception(f"download of {model_name} failed")
|
||||
except Exception as e:
|
||||
@@ -321,25 +311,24 @@ def get_root(root: str = None) -> str:
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
return str(config.root_path)
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(npyscreen.FormMultiPage):
|
||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts = self.parentApp.invokeai_opts
|
||||
first_time = not (Globals.root / Globals.initfile).exists()
|
||||
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
for i in [
|
||||
"Configure startup settings. You can come back and change these later.",
|
||||
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
|
||||
"Use cursor arrows to make a checkbox selection, and space to toggle.",
|
||||
]:
|
||||
label = """Configure startup settings. You can come back and change these later.
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@@ -366,7 +355,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
value=old_opts.outdir or str(default_output_dir()),
|
||||
value=str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
@@ -381,22 +370,21 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.safety_checker = self.add_widget_intelligent(
|
||||
self.nsfw_checker = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="NSFW checker",
|
||||
value=old_opts.safety_checker,
|
||||
value=old_opts.nsfw_checker,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
for i in [
|
||||
"If you have an account at HuggingFace you may paste your access token here",
|
||||
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
|
||||
"See https://huggingface.co/settings/tokens",
|
||||
]:
|
||||
label = """If you have an account at HuggingFace you may optionally paste your access token here
|
||||
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
|
||||
"""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
value=line,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
@@ -435,17 +423,10 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.xformers = self.add_widget_intelligent(
|
||||
self.xformers_enabled = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Enable xformers support if available",
|
||||
value=old_opts.xformers,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.ckpt_convert = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Load legacy checkpoint models into memory as diffusers models",
|
||||
value=old_opts.ckpt_convert,
|
||||
value=old_opts.xformers_enabled,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
@@ -480,19 +461,41 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Directory containing embedding/textual inversion files:",
|
||||
value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.embedding_path = self.add_widget_intelligent(
|
||||
self.embedding_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
name=" Textual Inversion Embeddings:",
|
||||
value=str(default_embedding_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lora_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" LoRA and LyCORIS:",
|
||||
value=str(default_lora_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.controlnet_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" ControlNets:",
|
||||
value=str(default_controlnet_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -505,11 +508,11 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
for i in [
|
||||
"BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
|
||||
"AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
|
||||
"https://huggingface.co/spaces/CompVis/stable-diffusion-license",
|
||||
]:
|
||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@@ -548,7 +551,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
|
||||
def validate_field_values(self, opt: Namespace) -> bool:
|
||||
bad_fields = []
|
||||
if not opt.license_acceptance:
|
||||
@@ -559,9 +562,9 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
)
|
||||
if not Path(opt.embedding_path).parent.exists():
|
||||
if not Path(opt.embedding_dir).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
|
||||
)
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:\n"
|
||||
@@ -576,20 +579,24 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"outdir",
|
||||
"safety_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers",
|
||||
"always_use_cpu",
|
||||
"embedding_path",
|
||||
"ckpt_convert",
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
"embedding_dir",
|
||||
"lora_dir",
|
||||
"controlnet_dir",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
# widget library workaround to make max_loaded_models an int rather than a float
|
||||
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||
|
||||
return new_opts
|
||||
|
||||
@@ -608,6 +615,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
"MAIN",
|
||||
editOptsForm,
|
||||
name="InvokeAI Startup Options",
|
||||
cycle_widgets=True,
|
||||
)
|
||||
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
|
||||
self.model_select = self.addForm(
|
||||
@@ -615,6 +623,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=True,
|
||||
cycle_widgets=True,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
@@ -628,18 +637,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = Args().parse_args([])
|
||||
outdir = Path(opts.outdir)
|
||||
if not outdir.is_absolute():
|
||||
opts.outdir = str(Globals.root / opts.outdir)
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
if not init_file.exists():
|
||||
opts.safety_checker = True
|
||||
opts.nsfw_checker = True
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> Namespace:
|
||||
return Namespace(
|
||||
starter_models=default_dataset()
|
||||
def default_user_selections(program_opts: Namespace) -> UserSelections:
|
||||
return UserSelections(
|
||||
install_models=default_dataset()
|
||||
if program_opts.default_only
|
||||
else recommended_datasets()
|
||||
if program_opts.yes_to_all
|
||||
@@ -647,26 +652,27 @@ def default_user_selections(program_opts: Namespace) -> Namespace:
|
||||
purge_deleted_models=False,
|
||||
scan_directory=None,
|
||||
autoscan_on_startup=None,
|
||||
import_model_paths=None,
|
||||
convert_to_diffusers=None,
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: str, yes_to_all: bool = False):
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||
|
||||
for name in (
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"databases",
|
||||
"loras",
|
||||
"controlnets",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
|
||||
configs_src = Path(configs.__path__[0])
|
||||
configs_dest = Path(root) / "configs"
|
||||
configs_dest = root / "configs"
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
@@ -677,8 +683,17 @@ def run_console_ui(
|
||||
) -> (Namespace, Namespace):
|
||||
# parse_args() will read from init file if present
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||
# The third argument is needed in the Windows 11 environment to
|
||||
# launch a console window running this program.
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure')
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
@@ -690,70 +705,66 @@ def run_console_ui(
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
"""
|
||||
Update the invokeai.init file with values from opts Namespace
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
"""
|
||||
# touch file if it doesn't exist
|
||||
if not init_file.exists():
|
||||
with open(init_file, "w") as f:
|
||||
f.write(INIT_FILE_PREAMBLE)
|
||||
|
||||
# We want to write in the changed arguments without clobbering
|
||||
# any other initialization values the user has entered. There is
|
||||
# no good way to do this because of the one-way nature of
|
||||
# argparse: i.e. --outdir could be --outdir, --out, or -o
|
||||
# initfile needs to be replaced with a fully structured format
|
||||
# such as yaml; this is a hack that will work much of the time
|
||||
args_to_skip = re.compile(
|
||||
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
|
||||
)
|
||||
# fix windows paths
|
||||
opts.outdir = opts.outdir.replace("\\", "/")
|
||||
opts.embedding_path = opts.embedding_path.replace("\\", "/")
|
||||
new_file = f"{init_file}.new"
|
||||
try:
|
||||
lines = [x.strip() for x in open(init_file, "r").readlines()]
|
||||
with open(new_file, "w") as out_file:
|
||||
for line in lines:
|
||||
if len(line) > 0 and not args_to_skip.match(line):
|
||||
out_file.write(line + "\n")
|
||||
out_file.write(
|
||||
f"""
|
||||
--outdir={opts.outdir}
|
||||
--embedding_path={opts.embedding_path}
|
||||
--precision={opts.precision}
|
||||
--max_loaded_models={int(opts.max_loaded_models)}
|
||||
--{'no-' if not opts.safety_checker else ''}nsfw_checker
|
||||
--{'no-' if not opts.xformers else ''}xformers
|
||||
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
|
||||
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
|
||||
{'--always_use_cpu' if opts.always_use_cpu else ''}
|
||||
"""
|
||||
)
|
||||
except OSError as e:
|
||||
print(f"** An error occurred while writing the init file: {str(e)}")
|
||||
|
||||
os.replace(new_file, init_file)
|
||||
|
||||
if opts.hf_token:
|
||||
HfLogin(opts.hf_token)
|
||||
# this will load current settings
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key,value in opts.__dict__.items():
|
||||
if hasattr(new_config,key):
|
||||
setattr(new_config,key,value)
|
||||
|
||||
with open(init_file,'w', encoding='utf-8') as file:
|
||||
file.write(new_config.to_yaml())
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return Globals.root / "outputs"
|
||||
|
||||
return config.root_path / "outputs"
|
||||
|
||||
# -------------------------------------
|
||||
def default_embedding_dir() -> Path:
|
||||
return Globals.root / "embeddings"
|
||||
return config.root_path / "embeddings"
|
||||
|
||||
# -------------------------------------
|
||||
def default_lora_dir() -> Path:
|
||||
return config.root_path / "loras"
|
||||
|
||||
# -------------------------------------
|
||||
def default_controlnet_dir() -> Path:
|
||||
return config.root_path / "controlnets"
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
opt = default_startup_options(initfile)
|
||||
opt.hf_token = HfFolder.get_token()
|
||||
write_opts(opt, initfile)
|
||||
|
||||
# -------------------------------------
|
||||
# Here we bring in
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format:Path):
|
||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
for attr in fields:
|
||||
if hasattr(old,attr):
|
||||
setattr(new,attr,getattr(old,attr))
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
new.nsfw_checker = old.safety_checker
|
||||
new.xformers_enabled = old.xformers
|
||||
new.conf_path = old.conf
|
||||
new.embedding_dir = old.embedding_path
|
||||
|
||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||
outfile.write(new.to_yaml())
|
||||
|
||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
@@ -809,8 +820,13 @@ def main():
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
# setting a global here
|
||||
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
|
||||
invoke_args = []
|
||||
if opt.root:
|
||||
invoke_args.extend(['--root',opt.root])
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
errors = set()
|
||||
|
||||
@@ -818,19 +834,26 @@ def main():
|
||||
models_to_download = default_user_selections(opt)
|
||||
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
init_file = Path(Globals.root, Globals.initfile)
|
||||
if not init_file.exists() or not global_config_file().exists():
|
||||
initialize_rootdir(Globals.root, opt.yes_to_all)
|
||||
old_init_file = config.root_path / 'invokeai.init'
|
||||
new_init_file = config.root_path / 'invokeai.yaml'
|
||||
if old_init_file.exists() and not new_init_file.exists():
|
||||
print('** Migrating invokeai.init to invokeai.yaml')
|
||||
migrate_init_file(old_init_file)
|
||||
# Load new init file into config
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, init_file)
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(
|
||||
precision="float32" if opt.full_precision else "float16"
|
||||
)
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
if init_options:
|
||||
write_opts(init_options, init_file)
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
print(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
@@ -840,7 +863,7 @@ def main():
|
||||
if opt.skip_support_models:
|
||||
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
|
||||
else:
|
||||
print("\n** DOWNLOADING SUPPORT MODELS **")
|
||||
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
|
||||
download_bert()
|
||||
download_sd1_clip()
|
||||
download_sd2_clip()
|
||||
@@ -858,6 +881,8 @@ def main():
|
||||
process_and_execute(opt, models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
if not opt.yes_to_all:
|
||||
input('Press any key to continue...')
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
||||
396
invokeai/backend/install/legacy_arg_parsing.py
Normal file
396
invokeai/backend/install/legacy_arg_parsing.py
Normal file
@@ -0,0 +1,396 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
from argparse import ArgumentParser
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"lms_k",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
"auto",
|
||||
"float32",
|
||||
"autocast",
|
||||
"float16",
|
||||
]
|
||||
|
||||
class FileArgumentParser(ArgumentParser):
|
||||
"""
|
||||
Supports reading defaults from an init file.
|
||||
"""
|
||||
def convert_arg_line_to_args(self, arg_line):
|
||||
return shlex.split(arg_line, comments=True)
|
||||
|
||||
|
||||
legacy_parser = FileArgumentParser(
|
||||
description=
|
||||
"""
|
||||
Generate images using Stable Diffusion.
|
||||
Use --web to launch the web interface.
|
||||
Use --from_file to load prompts from a file path or standard input ("-").
|
||||
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
|
||||
Other command-line arguments are defaults that can usually be overridden
|
||||
prompt the command prompt.
|
||||
""",
|
||||
fromfile_prefix_chars='@',
|
||||
)
|
||||
general_group = legacy_parser.add_argument_group('General')
|
||||
model_group = legacy_parser.add_argument_group('Model selection')
|
||||
file_group = legacy_parser.add_argument_group('Input/output')
|
||||
web_server_group = legacy_parser.add_argument_group('Web server')
|
||||
render_group = legacy_parser.add_argument_group('Rendering')
|
||||
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
|
||||
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
|
||||
|
||||
deprecated_group.add_argument('--laion400m')
|
||||
deprecated_group.add_argument('--weights') # deprecated
|
||||
general_group.add_argument(
|
||||
'--version','-V',
|
||||
action='store_true',
|
||||
help='Print InvokeAI version number'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--root_dir',
|
||||
default=None,
|
||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--config',
|
||||
'-c',
|
||||
'-config',
|
||||
dest='conf',
|
||||
default='./configs/models.yaml',
|
||||
help='Path to configuration file for alternate models.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--model',
|
||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--weight_dirs',
|
||||
nargs='+',
|
||||
type=str,
|
||||
help='List of one or more directories that will be auto-scanned for new model weights to import',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0,9),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'-F',
|
||||
'--full_precision',
|
||||
dest='full_precision',
|
||||
action='store_true',
|
||||
help='Deprecated way to set --precision=float32',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--max_loaded_models',
|
||||
dest='max_loaded_models',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--free_gpu_mem',
|
||||
dest='free_gpu_mem',
|
||||
action='store_true',
|
||||
help='Force free gpu memory before final decoding',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--sequential_guidance',
|
||||
dest='sequential_guidance',
|
||||
action='store_true',
|
||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
|
||||
"at the expense of speed",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--xformers',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable/disable xformers support (default enabled if installed)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--always_use_cpu",
|
||||
dest="always_use_cpu",
|
||||
action="store_true",
|
||||
help="Force use of CPU even if GPU is available"
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--precision',
|
||||
dest='precision',
|
||||
type=str,
|
||||
choices=PRECISION_CHOICES,
|
||||
metavar='PRECISION',
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default='auto',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--ckpt_convert',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='ckpt_convert',
|
||||
default=True,
|
||||
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--internet',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='internet_available',
|
||||
default=True,
|
||||
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--nsfw_checker',
|
||||
'--safety_checker',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='safety_checker',
|
||||
default=False,
|
||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoimport',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoconvert',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--patchmatch',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--from_file',
|
||||
dest='infile',
|
||||
type=str,
|
||||
help='If specified, load prompts from this file',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--outdir',
|
||||
'-o',
|
||||
type=str,
|
||||
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
|
||||
default='outputs',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--prompt_as_dir',
|
||||
'-p',
|
||||
action='store_true',
|
||||
help='Place images in subdirectories named after the prompt.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--fnformat',
|
||||
default='{prefix}.{seed}.png',
|
||||
type=str,
|
||||
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-s',
|
||||
'--steps',
|
||||
type=int,
|
||||
default=50,
|
||||
help='Number of steps'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-W',
|
||||
'--width',
|
||||
type=int,
|
||||
help='Image width, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-H',
|
||||
'--height',
|
||||
type=int,
|
||||
help='Image height, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-C',
|
||||
'--cfg_scale',
|
||||
default=7.5,
|
||||
type=float,
|
||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--sampler',
|
||||
'-A',
|
||||
'-m',
|
||||
dest='sampler_name',
|
||||
type=str,
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar='SAMPLER_NAME',
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default='k_lms',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--log_tokenization',
|
||||
'-t',
|
||||
action='store_true',
|
||||
help='shows how the prompt is split into tokens'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-f',
|
||||
'--strength',
|
||||
type=float,
|
||||
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
'--fit',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||
)
|
||||
|
||||
render_group.add_argument(
|
||||
'--grid',
|
||||
'-g',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='generate a grid'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embedding_directory',
|
||||
'--embedding_path',
|
||||
dest='embedding_path',
|
||||
default='embeddings',
|
||||
type=str,
|
||||
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--lora_directory',
|
||||
dest='lora_path',
|
||||
default='loras',
|
||||
type=str,
|
||||
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embeddings',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable embedding directory (default). Use --no-embeddings to disable.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--enable_image_debugging',
|
||||
action='store_true',
|
||||
help='Generates debugging image to display'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--karras_max',
|
||||
type=int,
|
||||
default=None,
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
||||
)
|
||||
# Restoration related args
|
||||
postprocessing_group.add_argument(
|
||||
'--no_restore',
|
||||
dest='restore',
|
||||
action='store_false',
|
||||
help='Disable face restoration with GFPGAN or codeformer',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--no_upscale',
|
||||
dest='esrgan',
|
||||
action='store_false',
|
||||
help='Disable upscaling with ESRGAN',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_bg_tile',
|
||||
type=int,
|
||||
default=400,
|
||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_denoise_str',
|
||||
type=float,
|
||||
default=0.75,
|
||||
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--gfpgan_model_path',
|
||||
type=str,
|
||||
default='./models/gfpgan/GFPGANv1.4.pth',
|
||||
help='Indicates the path to the GFPGAN model',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web',
|
||||
dest='web',
|
||||
action='store_true',
|
||||
help='Start in web server mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web_develop',
|
||||
dest='web_develop',
|
||||
action='store_true',
|
||||
help='Start in web server development mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_verbose",
|
||||
action="store_true",
|
||||
help="Enables verbose logging",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--cors",
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--host',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default='9090',
|
||||
help='Web server: Port to listen on'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--certfile',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--keyfile',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--gui',
|
||||
dest='gui',
|
||||
action='store_true',
|
||||
help='Start InvokeAI GUI',
|
||||
)
|
||||
@@ -6,26 +6,30 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass,field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List
|
||||
from typing import List, Dict, Callable
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url
|
||||
from huggingface_hub import hf_hub_url, HfFolder
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir
|
||||
from ..model_management import ModelManager
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
@@ -35,6 +39,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
# logger
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
@@ -45,58 +52,98 @@ Config_preamble = """
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class ModelInstallList:
|
||||
'''Class for listing models to be installed/removed'''
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class UserSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
purge_deleted_models: bool=field(default_factory=list)
|
||||
install_cn_models: List[str] = field(default_factory=list)
|
||||
remove_cn_models: List[str] = field(default_factory=list)
|
||||
install_lora_models: List[str] = field(default_factory=list)
|
||||
remove_lora_models: List[str] = field(default_factory=list)
|
||||
install_ti_models: List[str] = field(default_factory=list)
|
||||
remove_ti_models: List[str] = field(default_factory=list)
|
||||
scan_directory: Path = None
|
||||
autoscan_on_startup: bool=False
|
||||
import_model_paths: str=None
|
||||
|
||||
def default_config_file():
|
||||
return Path(global_config_dir()) / "models.yaml"
|
||||
|
||||
return config.model_conf_path
|
||||
|
||||
def sd_configs():
|
||||
return Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
return config.legacy_conf_path
|
||||
|
||||
def initial_models():
|
||||
global Datasets
|
||||
if Datasets:
|
||||
return Datasets
|
||||
return (Datasets := OmegaConf.load(Dataset_path))
|
||||
|
||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||
|
||||
def install_requested_models(
|
||||
install_initial_models: List[str] = None,
|
||||
remove_models: List[str] = None,
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
diffusers: ModelInstallList = None,
|
||||
controlnet: ModelInstallList = None,
|
||||
lora: ModelInstallList = None,
|
||||
ti: ModelInstallList = None,
|
||||
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
model_config_file_callback: Callable[[Path],Path] = None
|
||||
):
|
||||
"""
|
||||
Entry point for installing/deleting starter models, or installing external models.
|
||||
"""
|
||||
access_token = HfFolder.get_token()
|
||||
config_file_path = config_file_path or default_config_file()
|
||||
if not config_file_path.exists():
|
||||
open(config_file_path, "w")
|
||||
|
||||
# prevent circular import here
|
||||
from ..model_management import ModelManager
|
||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||
if controlnet:
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
if remove_models and len(remove_models) > 0:
|
||||
print("== DELETING UNCHECKED STARTER MODELS ==")
|
||||
for model in remove_models:
|
||||
print(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
if lora:
|
||||
model_manager.install_lora_models(lora.install_models, access_token=access_token)
|
||||
model_manager.delete_lora_models(lora.remove_models)
|
||||
|
||||
if install_initial_models and len(install_initial_models) > 0:
|
||||
print("== INSTALLING SELECTED STARTER MODELS ==")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models=install_initial_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
update_config_file(successfully_downloaded, config_file_path)
|
||||
if len(successfully_downloaded) < len(install_initial_models):
|
||||
print("** Some of the model downloads were not successful")
|
||||
if ti:
|
||||
model_manager.install_ti_models(ti.install_models, access_token=access_token)
|
||||
model_manager.delete_ti_models(ti.remove_models)
|
||||
|
||||
if diffusers:
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("Processing requested deletions")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("Installing requested models")
|
||||
downloaded_paths = download_weight_datasets(
|
||||
models=diffusers.install_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
)
|
||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||
if len(successful) > 0:
|
||||
update_config_file(successful, config_file_path)
|
||||
if len(successful) < len(diffusers.install_models):
|
||||
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
|
||||
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
|
||||
|
||||
# due to above, we have to reload the model manager because conf file
|
||||
# was changed behind its back
|
||||
@@ -107,12 +154,14 @@ def install_requested_models(
|
||||
external_models.append(str(scan_directory))
|
||||
|
||||
if len(external_models) > 0:
|
||||
print("== INSTALLING EXTERNAL MODELS ==")
|
||||
logger.info("INSTALLING EXTERNAL MODELS")
|
||||
for path_url_or_repo in external_models:
|
||||
try:
|
||||
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
commit_to_conf=config_file_path,
|
||||
config_file_callback = model_config_file_callback,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
@@ -120,17 +169,22 @@ def install_requested_models(
|
||||
pass
|
||||
|
||||
if scan_at_startup and scan_directory.is_dir():
|
||||
argument = "--autoconvert"
|
||||
initfile = Path(Globals.root, Globals.initfile)
|
||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||
directory = str(scan_directory).replace("\\", "/")
|
||||
with open(initfile, "r") as input:
|
||||
with open(replacement, "w") as output:
|
||||
while line := input.readline():
|
||||
if not line.startswith(argument):
|
||||
output.writelines([line])
|
||||
output.writelines([f"{argument} {directory}"])
|
||||
os.replace(replacement, initfile)
|
||||
update_autoconvert_dir(scan_directory)
|
||||
else:
|
||||
update_autoconvert_dir(None)
|
||||
|
||||
def update_autoconvert_dir(autodir: Path):
|
||||
'''
|
||||
Update the "autoconvert_dir" option in invokeai.yaml
|
||||
'''
|
||||
invokeai_config_path = config.init_file_path
|
||||
conf = OmegaConf.load(invokeai_config_path)
|
||||
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
|
||||
yaml = OmegaConf.to_yaml(conf)
|
||||
tmpfile = invokeai_config_path.parent / "new_config.tmp"
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(yaml)
|
||||
tmpfile.replace(invokeai_config_path)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
@@ -142,33 +196,21 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> dict:
|
||||
datasets = dict()
|
||||
def recommended_datasets() -> List['str']:
|
||||
datasets = set()
|
||||
for ds in initial_models().keys():
|
||||
if initial_models()[ds].get("recommended", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
datasets.add(ds)
|
||||
return list(datasets)
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
|
||||
datasets = dict()
|
||||
datasets = set()
|
||||
for ds in initial_models().keys():
|
||||
if initial_models()[ds].get("default", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
datasets.add(ds)
|
||||
return list(datasets)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -183,14 +225,14 @@ def all_datasets() -> dict:
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||
print(
|
||||
logger.warning(
|
||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
||||
)
|
||||
print(f"model.ckpt => {new_name}")
|
||||
logger.warning(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
@@ -203,7 +245,7 @@ def download_weight_datasets(
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models:
|
||||
print(f"Downloading {mod}:")
|
||||
logger.info(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
initial_models()[mod], access_token, precision=precision
|
||||
)
|
||||
@@ -224,11 +266,10 @@ def _download_repo_or_file(
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
@@ -239,9 +280,12 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
model_class: object, model_name: str, **kwargs
|
||||
):
|
||||
path = global_cache_dir(cache_subdir)
|
||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||
|
||||
path = config.cache_dir
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
@@ -272,10 +316,10 @@ def _download_diffusion_weights(
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
if 'Revision Not Found' in str(e):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
logger.error(str(e))
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
@@ -283,9 +327,13 @@ def _download_diffusion_weights(
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
) -> Path:
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
@@ -305,20 +353,19 @@ def hf_download_with_resume(
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
@@ -331,7 +378,7 @@ def hf_download_with_resume(
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
@@ -356,8 +403,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
logger.warning(
|
||||
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
@@ -374,16 +421,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
logger.info("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
logger.info(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -417,7 +464,7 @@ def new_config_file_contents(
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
successfully_downloaded[model], start=config.root_dir
|
||||
)
|
||||
stanza["config"] = os.path.normpath(
|
||||
os.path.join(sd_configs(), mod["config"])
|
||||
@@ -450,14 +497,14 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
|
||||
print(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
logger.warning(
|
||||
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
)
|
||||
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = Path(Globals.root) / weights
|
||||
weights = config.root_dir / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
logger.error(str(e))
|
||||
@@ -1,11 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .convert_ckpt_to_diffusers import (
|
||||
convert_ckpt_to_diffusers,
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from .model_manager import ModelManager,SDModelComponent
|
||||
|
||||
|
||||
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
|
||||
@@ -26,12 +26,15 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
from .model_manager import ModelManager
|
||||
from .model_cache import ModelCache
|
||||
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
||||
@@ -56,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||
LDMBertConfig,
|
||||
LDMBertModel,
|
||||
)
|
||||
from diffusers.pipelines.paint_by_example import (
|
||||
PaintByExampleImageEncoder,
|
||||
PaintByExamplePipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
@@ -74,6 +73,7 @@ from transformers import (
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
MODEL_ROOT = None
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
@@ -159,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(
|
||||
new_item, n_shave_prefix_segments=n_shave_prefix_segments
|
||||
@@ -184,7 +184,6 @@ def assign_to_checkpoint(
|
||||
paths,
|
||||
checkpoint,
|
||||
old_checkpoint,
|
||||
attention_paths_to_split=None,
|
||||
additional_replacements=None,
|
||||
config=None,
|
||||
):
|
||||
@@ -199,35 +198,9 @@ def assign_to_checkpoint(
|
||||
paths, list
|
||||
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape(
|
||||
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
|
||||
)
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
|
||||
# These have already been assigned
|
||||
if (
|
||||
attention_paths_to_split is not None
|
||||
and new_path in attention_paths_to_split
|
||||
):
|
||||
continue
|
||||
|
||||
# Global renaming happens here
|
||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||
@@ -246,14 +219,14 @@ def assign_to_checkpoint(
|
||||
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
elif "to_out.0.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
|
||||
|
||||
def create_unet_diffusers_config(original_config, image_size: int):
|
||||
@@ -613,16 +586,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
# Extract state dict for VAE. Works both with burnt-in
|
||||
# VAEs, and with standalone VAEs.
|
||||
|
||||
# checkpoint can either be a all-in-one stable diffusion
|
||||
# model, or an isolated vae .ckpt. This tests for
|
||||
# a key that will be present in the all-in-one model
|
||||
# that isn't present in the isolated ckpt.
|
||||
probe_key = "first_stage_model.encoder.conv_in.weight"
|
||||
if probe_key in checkpoint:
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
else:
|
||||
vae_state_dict = checkpoint
|
||||
|
||||
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
|
||||
return new_checkpoint
|
||||
|
||||
def convert_ldm_vae_state_dict(vae_state_dict, config):
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
@@ -842,10 +828,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
|
||||
)
|
||||
|
||||
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
@@ -897,82 +880,10 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def convert_paint_by_example_checkpoint(checkpoint):
|
||||
cache_dir = global_cache_dir("hub")
|
||||
config = CLIPVisionConfig.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
model = PaintByExampleImageEncoder(config)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
|
||||
key
|
||||
]
|
||||
|
||||
# load clip vision
|
||||
model.model.load_state_dict(text_model_dict)
|
||||
|
||||
# load mapper
|
||||
keys_mapper = {
|
||||
k[len("cond_stage_model.mapper.res") :]: v
|
||||
for k, v in checkpoint.items()
|
||||
if k.startswith("cond_stage_model.mapper")
|
||||
}
|
||||
|
||||
MAPPING = {
|
||||
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
||||
"attn.c_proj": ["attn1.to_out.0"],
|
||||
"ln_1": ["norm1"],
|
||||
"ln_2": ["norm3"],
|
||||
"mlp.c_fc": ["ff.net.0.proj"],
|
||||
"mlp.c_proj": ["ff.net.2"],
|
||||
}
|
||||
|
||||
mapped_weights = {}
|
||||
for key, value in keys_mapper.items():
|
||||
prefix = key[: len("blocks.i")]
|
||||
suffix = key.split(prefix)[-1].split(".")[-1]
|
||||
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
||||
mapped_names = MAPPING[name]
|
||||
|
||||
num_splits = len(mapped_names)
|
||||
for i, mapped_name in enumerate(mapped_names):
|
||||
new_name = ".".join([prefix, mapped_name, suffix])
|
||||
shape = value.shape[0] // num_splits
|
||||
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
||||
|
||||
model.mapper.load_state_dict(mapped_weights)
|
||||
|
||||
# load final layer norm
|
||||
model.final_layer_norm.load_state_dict(
|
||||
{
|
||||
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
||||
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
# load final proj
|
||||
model.proj_out.load_state_dict(
|
||||
{
|
||||
"bias": checkpoint["proj_out.bias"],
|
||||
"weight": checkpoint["proj_out.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
# load uncond vector
|
||||
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
||||
return model
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
cache_dir = global_cache_dir("hub")
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||
MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||
subfolder='text_encoder',
|
||||
)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -1048,22 +959,30 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
|
||||
new_key = f'first_stage_model.{vae_key}'
|
||||
checkpoint[new_key] = state_dict[vae_key]
|
||||
|
||||
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
|
||||
vae_config = create_vae_diffusers_config(
|
||||
vae_config, image_size=image_size
|
||||
)
|
||||
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
checkpoint, vae_config
|
||||
)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
original_config_file: str = None,
|
||||
num_in_channels: int = None,
|
||||
scheduler_type: str = "pndm",
|
||||
pipeline_type: str = None,
|
||||
image_size: int = None,
|
||||
prediction_type: str = None,
|
||||
model_version: BaseModelType,
|
||||
model_variant: ModelVariantType,
|
||||
original_config_file: str,
|
||||
extract_ema: bool = True,
|
||||
upcast_attn: bool = False,
|
||||
vae: AutoencoderKL = None,
|
||||
vae_path: str = None,
|
||||
precision: torch.dtype = torch.float32,
|
||||
return_generator_pipeline: bool = False,
|
||||
scan_needed:bool=True,
|
||||
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
|
||||
upcast_attention: bool = False,
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
|
||||
scan_needed: bool = True,
|
||||
) -> StableDiffusionPipeline:
|
||||
"""
|
||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||
config file.
|
||||
@@ -1075,149 +994,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
:param checkpoint_path: Path to `.ckpt` file.
|
||||
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
|
||||
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
|
||||
Base. Use 768 for Stable Diffusion v2.
|
||||
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
|
||||
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
|
||||
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
|
||||
inferred.
|
||||
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
|
||||
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
|
||||
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
|
||||
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
|
||||
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
|
||||
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
|
||||
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
||||
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
||||
running stable diffusion 2.1.
|
||||
:param vae: A diffusers VAE to load into the pipeline.
|
||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
if Path(checkpoint_path).suffix == '.ckpt':
|
||||
if scan_needed:
|
||||
ModelManager.scan_model(checkpoint_path,checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
else:
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
|
||||
cache_dir = global_cache_dir("hub")
|
||||
pipeline_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if return_generator_pipeline
|
||||
else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
logger.debug("global_step key not found in model")
|
||||
global_step = None
|
||||
if scan_needed:
|
||||
ModelCache.scan_model(checkpoint_path, checkpoint_path)
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
upcast_attention = False
|
||||
if original_config_file is None:
|
||||
model_type = ModelManager.probe_model_type(checkpoint)
|
||||
|
||||
if model_type == SDLegacyType.V2_v:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
|
||||
)
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
original_config_file = (
|
||||
global_config_dir()
|
||||
/ "stable-diffusion"
|
||||
/ "v1-inpainting-inference.yaml"
|
||||
)
|
||||
|
||||
elif model_type == SDLegacyType.V1:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Unknown checkpoint type")
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"][
|
||||
"in_channels"
|
||||
] = num_in_channels
|
||||
|
||||
if (
|
||||
"parameterization" in original_config["model"]["params"]
|
||||
and original_config["model"]["params"]["parameterization"] == "v"
|
||||
):
|
||||
if prediction_type is None:
|
||||
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
|
||||
# as it relies on a brittle global step parameter here
|
||||
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
|
||||
if image_size is None:
|
||||
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
|
||||
# as it relies on a brittle global step parameter here
|
||||
image_size = 512 if global_step == 875000 else 768
|
||||
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
|
||||
image_size = 768
|
||||
else:
|
||||
if prediction_type is None:
|
||||
prediction_type = "epsilon"
|
||||
if image_size is None:
|
||||
image_size = 512
|
||||
image_size = 512
|
||||
|
||||
#
|
||||
# convert scheduler
|
||||
#
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
|
||||
scheduler = DDIMScheduler(
|
||||
scheduler = PNDMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type=prediction_type,
|
||||
skip_prk_steps=True
|
||||
)
|
||||
# make sure scheduler works correctly with DDIM
|
||||
scheduler.register_to_config(clip_sample=False)
|
||||
|
||||
if scheduler_type == "pndm":
|
||||
config = dict(scheduler.config)
|
||||
config["skip_prk_steps"] = True
|
||||
scheduler = PNDMScheduler.from_config(config)
|
||||
elif scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "heun":
|
||||
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == 'unipc':
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
#
|
||||
# convert unet
|
||||
#
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
)
|
||||
@@ -1230,44 +1068,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
# If a replacement VAE path was specified, we'll incorporate that into
|
||||
# the checkpoint model and then convert it
|
||||
if vae_path:
|
||||
logger.debug(f"Converting VAE {vae_path}")
|
||||
replace_checkpoint_vae(checkpoint,vae_path)
|
||||
# otherwise we use the original VAE, provided that
|
||||
# an externally loaded diffusers VAE was not passed
|
||||
elif not vae:
|
||||
logger.debug("Using checkpoint model's original VAE")
|
||||
#
|
||||
# convert vae
|
||||
#
|
||||
|
||||
if vae:
|
||||
logger.debug("Using replacement diffusers VAE")
|
||||
else: # convert the original or replacement VAE
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||
checkpoint, vae_config
|
||||
)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
vae = convert_ldm_vae_to_diffusers(
|
||||
checkpoint,
|
||||
original_config,
|
||||
image_size,
|
||||
)
|
||||
|
||||
# Convert the text model.
|
||||
model_type = pipeline_type
|
||||
if model_type is None:
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(
|
||||
"."
|
||||
)[-1]
|
||||
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2",
|
||||
subfolder="tokenizer",
|
||||
cache_dir=cache_dir,
|
||||
MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||
subfolder='tokenizer',
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae.to(precision),
|
||||
text_encoder=text_model.to(precision),
|
||||
tokenizer=tokenizer,
|
||||
@@ -1277,49 +1096,26 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
elif model_type == "PaintByExample":
|
||||
vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||
)
|
||||
pipe = PaintByExamplePipeline(
|
||||
vae=vae,
|
||||
image_encoder=vision_model,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker",
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae.to(precision),
|
||||
text_encoder=text_model.to(precision),
|
||||
tokenizer=tokenizer,
|
||||
unet=unet.to(precision),
|
||||
scheduler=scheduler,
|
||||
safety_checker=None if return_generator_pipeline else safety_checker.to(precision),
|
||||
safety_checker=safety_checker.to(precision),
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained(
|
||||
"bert-base-uncased", cache_dir=cache_dir
|
||||
)
|
||||
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ROOT / "bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(
|
||||
vqvae=vae,
|
||||
bert=text_model,
|
||||
@@ -1333,15 +1129,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
|
||||
def convert_ckpt_to_diffusers(
|
||||
checkpoint_path: Union[str, Path],
|
||||
dump_path: Union[str, Path],
|
||||
**kwargs,
|
||||
checkpoint_path: Union[str, Path],
|
||||
dump_path: Union[str, Path],
|
||||
model_root: Union[str, Path],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Takes all the arguments of load_pipeline_from_original_stable_diffusion_ckpt(),
|
||||
and in addition a path-like object indicating the location of the desired diffusers
|
||||
model to be written.
|
||||
"""
|
||||
# setting global here to avoid massive changes late at night
|
||||
global MODEL_ROOT
|
||||
MODEL_ROOT = Path(model_root) / 'core/convert'
|
||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
|
||||
|
||||
pipe.save_pretrained(
|
||||
|
||||
678
invokeai/backend/model_management/lora.py
Normal file
678
invokeai/backend/model_management/lora.py
Normal file
@@ -0,0 +1,678 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
|
||||
class LoRALayerBase:
|
||||
#rank: Optional[int]
|
||||
#alpha: Optional[float]
|
||||
#bias: Optional[torch.Tensor]
|
||||
#layer_key: str
|
||||
|
||||
#@property
|
||||
#def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if (
|
||||
"bias_indices" in values
|
||||
and "bias_values" in values
|
||||
and "bias_size" in values
|
||||
):
|
||||
self.bias = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def forward(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
||||
multiplier: float,
|
||||
):
|
||||
if type(module) == torch.nn.Conv2d:
|
||||
op = torch.nn.functional.conv2d
|
||||
extra_args = dict(
|
||||
stride=module.stride,
|
||||
padding=module.padding,
|
||||
dilation=module.dilation,
|
||||
groups=module.groups,
|
||||
)
|
||||
|
||||
else:
|
||||
op = torch.nn.functional.linear
|
||||
extra_args = {}
|
||||
|
||||
weight = self.get_weight(module)
|
||||
|
||||
bias = self.bias if self.bias is not None else 0
|
||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
return op(
|
||||
*input_h,
|
||||
(weight + bias).view(module.weight.shape),
|
||||
None,
|
||||
**extra_args,
|
||||
) * multiplier * scale
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
#up: torch.Tensor
|
||||
#mid: Optional[torch.Tensor]
|
||||
#down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(up.shape[0], up.shape[1])
|
||||
down = self.down.reshape(up.shape[0], up.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
#w1_a: torch.Tensor
|
||||
#w1_b: torch.Tensor
|
||||
#w2_a: torch.Tensor
|
||||
#w2_b: torch.Tensor
|
||||
#t1: Optional[torch.Tensor] = None
|
||||
#t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(module_key, rank, alpha, bias)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1 = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2 = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
if self.t1 is None:
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
|
||||
)
|
||||
rebuild2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
|
||||
)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
#w1: Optional[torch.Tensor] = None
|
||||
#w1_a: Optional[torch.Tensor] = None
|
||||
#w1_b: Optional[torch.Tensor] = None
|
||||
#w2: Optional[torch.Tensor] = None
|
||||
#w2_a: Optional[torch.Tensor] = None
|
||||
#w2_b: Optional[torch.Tensor] = None
|
||||
#t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(module_key, rank, alpha, bias)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1 = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2 = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2 = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, module: torch.nn.Module):
|
||||
w1 = self.w1
|
||||
if w1 is None:
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoRAModel: #(torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, LoRALayer]
|
||||
_device: torch.device
|
||||
_dtype: torch.dtype
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, LoRALayer],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self._name = name
|
||||
self._device = device or torch.cpu
|
||||
self._dtype = dtype or torch.float32
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> LoRAModel:
|
||||
# TODO: try revert if exception?
|
||||
for key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
self._device = device
|
||||
self._dtype = dtype
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
name=file_path.stem, # TODO:
|
||||
layers=dict(),
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(
|
||||
f">> Encountered unknown lora layer module in {self.name}: {layer_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict):
|
||||
state_dict_groupped = dict()
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = dict()
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
||||
(lora_model2, 0.4),
|
||||
]
|
||||
with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
# unet with applied loras
|
||||
# unmodified unet
|
||||
|
||||
"""
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix):].split('_')
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = module_key.rstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@staticmethod
|
||||
def _lora_forward_hook(
|
||||
applied_loras: List[Tuple[LoraModel, float]],
|
||||
layer_name: str,
|
||||
):
|
||||
|
||||
def lora_forward(module, input_h, output):
|
||||
if len(applied_loras) == 0:
|
||||
return output
|
||||
|
||||
for lora, weight in applied_loras:
|
||||
layer = lora.layers.get(layer_name, None)
|
||||
if layer is None:
|
||||
continue
|
||||
output += layer.forward(module, input_h, weight)
|
||||
return output
|
||||
|
||||
return lora_forward
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
):
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
loras: List[Tuple[LoraModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
hooks = dict()
|
||||
try:
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
if module_key not in hooks:
|
||||
hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
finally:
|
||||
for module_key, hook in hooks.items():
|
||||
hook.remove()
|
||||
hooks.clear()
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
cls,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
ti_list: List[Any],
|
||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||
init_tokens_count = None
|
||||
new_tokens_added = None
|
||||
|
||||
try:
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||
|
||||
def _get_trigger(ti, index):
|
||||
trigger = ti.name
|
||||
if index > 0:
|
||||
trigger += f"-!pad-{i}"
|
||||
return f"<{trigger}>"
|
||||
|
||||
# modify tokenizer
|
||||
new_tokens_added = 0
|
||||
for ti in ti_list:
|
||||
for i in range(ti.embedding.shape[0]):
|
||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
||||
|
||||
# modify text_encoder
|
||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
||||
model_embeddings = text_encoder.get_input_embeddings()
|
||||
|
||||
for ti in ti_list:
|
||||
ti_tokens = []
|
||||
for i in range(ti.embedding.shape[0]):
|
||||
embedding = ti.embedding[i]
|
||||
trigger = _get_trigger(ti, i)
|
||||
|
||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||
if token_id == ti_tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
||||
|
||||
if model_embeddings.weight.data[token_id].shape != embedding.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
|
||||
)
|
||||
|
||||
model_embeddings.weight.data[token_id] = embedding
|
||||
ti_tokens.append(token_id)
|
||||
|
||||
if len(ti_tokens) > 1:
|
||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
||||
|
||||
yield ti_tokenizer, ti_manager
|
||||
|
||||
finally:
|
||||
if init_tokens_count and new_tokens_added:
|
||||
text_encoder.resize_token_embeddings(init_tokens_count)
|
||||
|
||||
|
||||
class TextualInversionModel:
|
||||
name: str
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
|
||||
result = cls() # TODO:
|
||||
result.name = file_path.stem # TODO:
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
|
||||
# both v1 and v2 format embeddings
|
||||
# difference mostly in metadata
|
||||
if "string_to_param" in state_dict:
|
||||
if len(state_dict["string_to_param"]) > 1:
|
||||
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
|
||||
|
||||
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
||||
|
||||
# v3 (easynegative)
|
||||
elif "emb_params" in state_dict:
|
||||
result.embedding = state_dict["emb_params"]
|
||||
|
||||
# v4(diffusers bin files)
|
||||
else:
|
||||
result.embedding = next(iter(state_dict.values()))
|
||||
|
||||
if not isinstance(result.embedding, torch.Tensor):
|
||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = dict()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
|
||||
self, token_ids: list[int]
|
||||
) -> list[int]:
|
||||
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
new_token_ids = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
if token_id in self.pad_tokens:
|
||||
new_token_ids.extend(self.pad_tokens[token_id])
|
||||
|
||||
return new_token_ids
|
||||
|
||||
391
invokeai/backend/model_management/model_cache.py
Normal file
391
invokeai/backend/model_management/model_cache.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_models_cached=6)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, types, Optional, Type, Any
|
||||
|
||||
import torch
|
||||
|
||||
import logging
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from .lora import LoRAModel, TextualInversionModel
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
|
||||
class _CacheRecord:
|
||||
size: int
|
||||
model: Any
|
||||
cache: ModelCache
|
||||
_locks: int
|
||||
|
||||
def __init__(self, cache, model: Any, size: int):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.cache = cache
|
||||
self._locks = 0
|
||||
|
||||
def lock(self):
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self):
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self):
|
||||
return self._locks > 0
|
||||
|
||||
@property
|
||||
def loaded(self):
|
||||
if self.model is not None and hasattr(self.model, "device"):
|
||||
return self.model.device != self.cache.storage_device
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
||||
execution_device: torch.device=torch.device('cuda'),
|
||||
storage_device: torch.device=torch.device('cpu'),
|
||||
precision: torch.dtype=torch.float16,
|
||||
sequential_offload: bool=False,
|
||||
lazy_offloading: bool=True,
|
||||
sha_chunksize: int = 16777216,
|
||||
logger: types.ModuleType = logger
|
||||
):
|
||||
'''
|
||||
:param max_models: Maximum number of models to cache in CPU RAM [4]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
'''
|
||||
#max_cache_size = 9999
|
||||
execution_device = torch.device('cuda')
|
||||
|
||||
self.model_infos: Dict[str, ModelBase] = dict()
|
||||
self.lazy_offloading = lazy_offloading
|
||||
#self.sequential_offload: bool=sequential_offload
|
||||
self.precision: torch.dtype=precision
|
||||
self.max_cache_size: int=max_cache_size
|
||||
self.execution_device: torch.device=execution_device
|
||||
self.storage_device: torch.device=storage_device
|
||||
self.sha_chunksize=sha_chunksize
|
||||
self.logger = logger
|
||||
|
||||
self._cached_models = dict()
|
||||
self._cache_stack = list()
|
||||
|
||||
def get_key(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
):
|
||||
|
||||
key = f"{model_path}:{base_model}:{model_type}"
|
||||
if submodel_type:
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
|
||||
#def get_model(
|
||||
# self,
|
||||
# repo_id_or_path: Union[str, Path],
|
||||
# model_type: ModelType = ModelType.Diffusers,
|
||||
# subfolder: Path = None,
|
||||
# submodel: ModelType = None,
|
||||
# revision: str = None,
|
||||
# attach_model_part: Tuple[ModelType, str] = (None, None),
|
||||
# gpu_load: bool = True,
|
||||
#) -> ModelLocker: # ?? what does it return
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
model_info_key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=None,
|
||||
)
|
||||
|
||||
if model_info_key not in self.model_infos:
|
||||
self.model_infos[model_info_key] = model_class(
|
||||
model_path,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
return self.model_infos[model_info_key]
|
||||
|
||||
# TODO: args
|
||||
def get_model(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
gpu_load: bool = True,
|
||||
) -> Any:
|
||||
|
||||
if not isinstance(model_path, Path):
|
||||
model_path = Path(model_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception(f"Model not found: {model_path}")
|
||||
|
||||
model_info = self._get_model_info(
|
||||
model_path=model_path,
|
||||
model_class=model_class,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=submodel,
|
||||
)
|
||||
|
||||
# TODO: lock for no copies on simultaneous calls?
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
|
||||
|
||||
# this will remove older cached models until
|
||||
# there is sufficient room to load the requested model
|
||||
self._make_cache_room(model_info.get_size(submodel))
|
||||
|
||||
# clean memory to make MemoryUsage() more accurate
|
||||
gc.collect()
|
||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||
if mem_used := model_info.get_size(submodel):
|
||||
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
||||
|
||||
cache_entry = _CacheRecord(self, model, mem_used)
|
||||
self._cached_models[key] = cache_entry
|
||||
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load)
|
||||
|
||||
class ModelLocker(object):
|
||||
def __init__(self, cache, key, model, gpu_load):
|
||||
self.gpu_load = gpu_load
|
||||
self.cache = cache
|
||||
self.key = key
|
||||
self.model = model
|
||||
self.cache_entry = self.cache._cached_models[self.key]
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
if not hasattr(self.model, 'to'):
|
||||
return self.model
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this
|
||||
# code to move it into GPU!
|
||||
if self.gpu_load:
|
||||
self.cache_entry.lock()
|
||||
|
||||
try:
|
||||
if self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models()
|
||||
|
||||
if self.model.device != self.cache.execution_device:
|
||||
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
|
||||
with VRAMUsage() as mem:
|
||||
self.model.to(self.cache.execution_device) # move into GPU
|
||||
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
||||
|
||||
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
|
||||
self.cache._print_cuda_stats()
|
||||
|
||||
except:
|
||||
self.cache_entry.unlock()
|
||||
raise
|
||||
|
||||
|
||||
# TODO: not fully understand
|
||||
# in the event that the caller wants the model in RAM, we
|
||||
# move it into CPU if it is in GPU and not locked
|
||||
elif self.cache_entry.loaded and not self.cache_entry.locked:
|
||||
self.model.to(self.cache.storage_device)
|
||||
|
||||
return self.model
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if not hasattr(self.model, 'to'):
|
||||
return
|
||||
|
||||
self.cache_entry.unlock()
|
||||
if not self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models()
|
||||
self.cache._print_cuda_stats()
|
||||
|
||||
# TODO: should it be called untrack_model?
|
||||
def uncache_model(self, cache_id: str):
|
||||
with suppress(ValueError):
|
||||
self._cache_stack.remove(cache_id)
|
||||
self._cached_models.pop(cache_id, None)
|
||||
|
||||
def model_hash(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
) -> str:
|
||||
'''
|
||||
Given the HF repo id or path to a model on disk, returns a unique
|
||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||
:param model_path: Path to model file/directory on disk.
|
||||
'''
|
||||
return self._local_model_hash(model_path)
|
||||
|
||||
def cache_size(self) -> float:
|
||||
"Return the current size of the cache, in GB"
|
||||
current_cache_size = sum([m.size for m in self._cached_models.values()])
|
||||
return current_cache_size / GIG
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.execution_device.type == 'cuda'
|
||||
|
||||
def _print_cuda_stats(self):
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % self.cache_size()
|
||||
|
||||
cached_models = 0
|
||||
loaded_models = 0
|
||||
locked_models = 0
|
||||
for model_info in self._cached_models.values():
|
||||
cached_models += 1
|
||||
if model_info.loaded:
|
||||
loaded_models += 1
|
||||
if model_info.locked:
|
||||
locked_models += 1
|
||||
|
||||
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
|
||||
|
||||
|
||||
def _make_cache_room(self, model_size):
|
||||
# calculate how much memory this model will require
|
||||
#multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = model_size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = sum([m.size for m in self._cached_models.values()])
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
|
||||
|
||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
|
||||
|
||||
# 2 refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
if not cache_entry.locked and refs <= 2:
|
||||
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
|
||||
current_size -= cache_entry.size
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
|
||||
def _offload_unlocked_models(self):
|
||||
for model_key, cache_entry in self._cached_models.items():
|
||||
if not cache_entry.locked and cache_entry.loaded:
|
||||
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
|
||||
cache_entry.model.to(self.storage_device)
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
path = Path(model_path)
|
||||
|
||||
hashpath = path / "checksum.sha256"
|
||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
self.logger.debug(f'computing hash of model {path.name}')
|
||||
for file in list(path.rglob("*.ckpt")) \
|
||||
+ list(path.rglob("*.safetensors")) \
|
||||
+ list(path.rglob("*.pth")):
|
||||
with open(file, "rb") as f:
|
||||
while chunk := f.read(self.sha_chunksize):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
||||
class VRAMUsage(object):
|
||||
def __init__(self):
|
||||
self.vram = None
|
||||
self.vram_used = 0
|
||||
|
||||
def __enter__(self):
|
||||
self.vram = torch.cuda.memory_allocated()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.vram_used = torch.cuda.memory_allocated() - self.vram
|
||||
118
invokeai/backend/model_management/model_install.py
Normal file
118
invokeai/backend/model_management/model_install.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Routines for downloading and installing models.
|
||||
"""
|
||||
import json
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from diffusers import ModelMixin
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from . import ModelManager
|
||||
from .models import BaseModelType, ModelType, VariantType
|
||||
from .model_probe import ModelProbe, ModelVariantInfo
|
||||
from .model_cache import SilenceWarnings
|
||||
|
||||
class ModelInstall(object):
|
||||
'''
|
||||
This class is able to download and install several different kinds of
|
||||
InvokeAI models. The helper function, if provided, is called on to distinguish
|
||||
between v2-base and v2-768 stable diffusion pipelines. This usually involves
|
||||
asking the user to select the proper type, as there is no way of distinguishing
|
||||
the two type of v2 file programmatically (as far as I know).
|
||||
'''
|
||||
def __init__(self,
|
||||
config: InvokeAIAppConfig,
|
||||
model_base_helper: Callable[[Path],BaseModelType]=None,
|
||||
clobber:bool = False
|
||||
):
|
||||
'''
|
||||
:param config: InvokeAI configuration object
|
||||
:param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
|
||||
:param clobber: If true, models with colliding names will be overwritten
|
||||
'''
|
||||
self.config = config
|
||||
self.clogger = clobber
|
||||
self.helper = model_base_helper
|
||||
self.prober = ModelProbe()
|
||||
|
||||
def install_checkpoint_file(self, checkpoint: Path)->dict:
|
||||
'''
|
||||
Install the checkpoint file at path and return a
|
||||
configuration entry that can be added to `models.yaml`.
|
||||
Model checkpoints and VAEs will be converted into
|
||||
diffusers before installation. Note that the model manager
|
||||
does not hold entries for anything but diffusers pipelines,
|
||||
and the configuration file stanzas returned from such models
|
||||
can be safely ignored.
|
||||
'''
|
||||
model_info = self.prober.probe(checkpoint, self.helper)
|
||||
if not model_info:
|
||||
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
|
||||
|
||||
key = ModelManager.create_key(
|
||||
model_name = checkpoint.stem,
|
||||
base_model = model_info.base_type,
|
||||
model_type = model_info.model_type,
|
||||
)
|
||||
destination_path = self._dest_path(model_info) / checkpoint
|
||||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._check_for_collision(destination_path)
|
||||
stanza = {
|
||||
key: dict(
|
||||
name = checkpoint.stem,
|
||||
description = f'{model_info.model_type} model {checkpoint.stem}',
|
||||
base = model_info.base_model.value,
|
||||
type = model_info.model_type.value,
|
||||
variant = model_info.variant_type.value,
|
||||
path = str(destination_path),
|
||||
)
|
||||
}
|
||||
|
||||
# non-pipeline; no conversion needed, just copy into right place
|
||||
if model_info.model_type != ModelType.Pipeline:
|
||||
shutil.copyfile(checkpoint, destination_path)
|
||||
stanza[key].update({'format': 'checkpoint'})
|
||||
|
||||
# pipeline - conversion needed here
|
||||
else:
|
||||
destination_path = self._dest_path(model_info) / checkpoint.stem
|
||||
config_file = self._pipeline_type_to_config_file(model_info.model_type)
|
||||
|
||||
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
with SilenceWarnings:
|
||||
convert_ckpt_to_diffusers(
|
||||
checkpoint,
|
||||
destination_path,
|
||||
extract_ema=True,
|
||||
original_config_file=config_file,
|
||||
scan_needed=False,
|
||||
)
|
||||
stanza[key].update({'format': 'folder',
|
||||
'path': destination_path, # no suffix on this
|
||||
})
|
||||
|
||||
return stanza
|
||||
|
||||
|
||||
def _check_for_collision(self, path: Path):
|
||||
if not path.exists():
|
||||
return
|
||||
if self.clobber:
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
|
||||
|
||||
def _staging_directory(self)->tempfile.TemporaryDirectory:
|
||||
return tempfile.TemporaryDirectory(dir=self.config.root_path)
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user