mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 12:08:23 -05:00
Compare commits
829 Commits
feat/contr
...
feat/api/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60a2fbec41 | ||
|
|
f15a328b80 | ||
|
|
811d9ab55a | ||
|
|
e00fed5c46 | ||
|
|
a3fa38b353 | ||
|
|
2e42a4bdd9 | ||
|
|
36f72b5a49 | ||
|
|
af42d7d347 | ||
|
|
8607b1994c | ||
|
|
e051c450ed | ||
|
|
50135b726e | ||
|
|
c647056287 | ||
|
|
30f20b55d5 | ||
|
|
1bca32ed16 | ||
|
|
7f91139e21 | ||
|
|
c53b7c7389 | ||
|
|
93f3658a4a | ||
|
|
68be95acbb | ||
|
|
813f79f0f9 | ||
|
|
c3ec86bc70 | ||
|
|
05a19753c6 | ||
|
|
a33327c651 | ||
|
|
6ad7cc4f2a | ||
|
|
c506355b8b | ||
|
|
d54168b8fb | ||
|
|
c91b071c47 | ||
|
|
9c57b18008 | ||
|
|
69539a0472 | ||
|
|
3f45294c61 | ||
|
|
fd03c7eebe | ||
|
|
07c49a5726 | ||
|
|
8c688f8e29 | ||
|
|
3d13167d32 | ||
|
|
f2bb507ebb | ||
|
|
fe8f3381fc | ||
|
|
2a6d11e645 | ||
|
|
01f46d3c7d | ||
|
|
5f76b62553 | ||
|
|
4bbe3b0d00 | ||
|
|
9ed86a08f1 | ||
|
|
0a50e2638c | ||
|
|
fc7c5da4dd | ||
|
|
a3357e073c | ||
|
|
d114833a12 | ||
|
|
96038bd075 | ||
|
|
2f383c2598 | ||
|
|
702a8d1f72 | ||
|
|
0a8390356f | ||
|
|
844058c0a5 | ||
|
|
7d74cbe29c | ||
|
|
62ac0ed2dc | ||
|
|
ae14adec2a | ||
|
|
6c2b39d1df | ||
|
|
0843028e6e | ||
|
|
de0fd87035 | ||
|
|
8b6c0be259 | ||
|
|
58fec84858 | ||
|
|
f223ad7776 | ||
|
|
00eabf630d | ||
|
|
6245a27650 | ||
|
|
fa1ac57c90 | ||
|
|
0f16b1c98d | ||
|
|
08e66c5451 | ||
|
|
563bf70c95 | ||
|
|
49d29420c4 | ||
|
|
ae9d0c6c1b | ||
|
|
d8d11f9bbb | ||
|
|
13fa0d3bc0 | ||
|
|
5eeb4b8e06 | ||
|
|
f5044c290d | ||
|
|
1b43276e5d | ||
|
|
294f086857 | ||
|
|
e5024bf5e9 | ||
|
|
79198b4bba | ||
|
|
1a2f0984db | ||
|
|
454683e6eb | ||
|
|
bbb2a08e8f | ||
|
|
bf116927e1 | ||
|
|
3d249c4fa3 | ||
|
|
fa338ddb6a | ||
|
|
b200451330 | ||
|
|
8283d23b74 | ||
|
|
2fc0a4d53b | ||
|
|
3ff732d583 | ||
|
|
840c632c0a | ||
|
|
40d6e4f287 | ||
|
|
fc5f9c30a6 | ||
|
|
229de2dbb8 | ||
|
|
cc22427f25 | ||
|
|
90333c0074 | ||
|
|
54e5301b35 | ||
|
|
b31fc43bfa | ||
|
|
9bcf0b2251 | ||
|
|
d4bc98c383 | ||
|
|
bc892c535c | ||
|
|
099e1e7c08 | ||
|
|
b1000e30c1 | ||
|
|
7bd94eac0e | ||
|
|
2c77563dcc | ||
|
|
603c9a587e | ||
|
|
1a5a2dfda9 | ||
|
|
090b7eeaf3 | ||
|
|
117536324c | ||
|
|
999c092b6a | ||
|
|
9e31b1f387 | ||
|
|
cb157ea530 | ||
|
|
5f6f38074d | ||
|
|
25b8dd340a | ||
|
|
fb06f5b892 | ||
|
|
1a7fb601dc | ||
|
|
cdcfda164d | ||
|
|
966b154a1f | ||
|
|
95fa66661c | ||
|
|
6247b79111 | ||
|
|
5831364f9c | ||
|
|
919b81cff1 | ||
|
|
065fff7db5 | ||
|
|
a664ee30a2 | ||
|
|
03f3ad435a | ||
|
|
2270c270ef | ||
|
|
4f7820719b | ||
|
|
fa285883ad | ||
|
|
474fca8e6a | ||
|
|
5dc0250b00 | ||
|
|
f269377a01 | ||
|
|
d0406024e3 | ||
|
|
aa3a969bd2 | ||
|
|
73a95973a8 | ||
|
|
bf4fe3c1ac | ||
|
|
d6c08ba469 | ||
|
|
69f0ba65f1 | ||
|
|
828c86964d | ||
|
|
54b7ddd63f | ||
|
|
a0dde66b5d | ||
|
|
b6b3b9f99c | ||
|
|
faa69f8a47 | ||
|
|
d92c7f5483 | ||
|
|
6b824eb112 | ||
|
|
72b4371804 | ||
|
|
fa290aff8d | ||
|
|
3d99d7ae8b | ||
|
|
2eb367969c | ||
|
|
9cdad95f48 | ||
|
|
707ed39300 | ||
|
|
6bbb5f061a | ||
|
|
6896e69e95 | ||
|
|
b17f4c1650 | ||
|
|
98493ed9e2 | ||
|
|
94c953deab | ||
|
|
fa4d88e163 | ||
|
|
b1e1e3efc7 | ||
|
|
3b9426eb72 | ||
|
|
e2e07696fc | ||
|
|
d6a959b000 | ||
|
|
c3935d3849 | ||
|
|
383e3d77cb | ||
|
|
31e97ead2a | ||
|
|
0b49995659 | ||
|
|
ff204db6b2 | ||
|
|
f74f3d6a3a | ||
|
|
713fb061e8 | ||
|
|
77b7680b32 | ||
|
|
ff63433591 | ||
|
|
31281d7181 | ||
|
|
72d1e4e404 | ||
|
|
91918e648b | ||
|
|
1390b65a9c | ||
|
|
82231369d3 | ||
|
|
7620bacc01 | ||
|
|
ea9cf04765 | ||
|
|
47301e6f85 | ||
|
|
f143fb7254 | ||
|
|
2bdb655375 | ||
|
|
41f7758977 | ||
|
|
8ae1eaaccc | ||
|
|
d66979073b | ||
|
|
c9e621093e | ||
|
|
e06ba40795 | ||
|
|
6571e4c2fd | ||
|
|
ff9240b51d | ||
|
|
18466e01fd | ||
|
|
e9821ab711 | ||
|
|
d6530df635 | ||
|
|
062b2cf46f | ||
|
|
082ecf6f25 | ||
|
|
1632ac6b9f | ||
|
|
877959b413 | ||
|
|
6e60f7517b | ||
|
|
296ee6b7ea | ||
|
|
7c7ffddb2b | ||
|
|
e1ae7842ff | ||
|
|
9687fe7bac | ||
|
|
a9a2bd90c2 | ||
|
|
47ca71a7eb | ||
|
|
a9c47237b1 | ||
|
|
33bbae2f47 | ||
|
|
fab7a1d337 | ||
|
|
cffcf80977 | ||
|
|
1a3fd05b81 | ||
|
|
c22c6ca135 | ||
|
|
3afb6a387f | ||
|
|
33e5ed7180 | ||
|
|
2067757fab | ||
|
|
b1b94a3d56 | ||
|
|
c9ee42450e | ||
|
|
10fe31c2a1 | ||
|
|
dc54cbb1fc | ||
|
|
070218aba7 | ||
|
|
f1c226b171 | ||
|
|
7004430380 | ||
|
|
1ddc620192 | ||
|
|
a7cebbd970 | ||
|
|
d97438b0b3 | ||
|
|
4522f3f4c9 | ||
|
|
6fe28980b0 | ||
|
|
4aec5d8ffc | ||
|
|
bbb4e8f5ef | ||
|
|
bce33ea62e | ||
|
|
e4705d5ce7 | ||
|
|
6764b2a854 | ||
|
|
970340cf62 | ||
|
|
043f9d9ba4 | ||
|
|
6f82801d07 | ||
|
|
3e3dd39ae4 | ||
|
|
89aa06e014 | ||
|
|
6cc00ef4b7 | ||
|
|
f31e62afad | ||
|
|
38fd2ad45d | ||
|
|
05b99b5377 | ||
|
|
08a14ee6d5 | ||
|
|
29fcc92da9 | ||
|
|
d78e3572e3 | ||
|
|
160267c71a | ||
|
|
fd47e70c92 | ||
|
|
9317b42e5f | ||
|
|
bdab73701f | ||
|
|
3ea5e78322 | ||
|
|
f609ee21a2 | ||
|
|
f51defeeb3 | ||
|
|
ee0225f4ba | ||
|
|
33a0af4637 | ||
|
|
d37b08a7dd | ||
|
|
9a796364da | ||
|
|
1ad4eb3a7b | ||
|
|
3767a453bb | ||
|
|
b0892d30a4 | ||
|
|
d9b1e4a98c | ||
|
|
a4dec8c1d6 | ||
|
|
8960ceb98b | ||
|
|
be79d088c0 | ||
|
|
009407ea3f | ||
|
|
6999d28c7f | ||
|
|
324e9eb74b | ||
|
|
56cff40362 | ||
|
|
2ba40c5e52 | ||
|
|
3ab147204c | ||
|
|
e4c89cba9c | ||
|
|
322ea84c4e | ||
|
|
f2b41c60ff | ||
|
|
754acec92f | ||
|
|
11fc7e40a5 | ||
|
|
d15bb88eb2 | ||
|
|
70ba36eefc | ||
|
|
7e70391c2b | ||
|
|
e2a94be336 | ||
|
|
63a86eefb4 | ||
|
|
b0727b9d47 | ||
|
|
d96e727dd5 | ||
|
|
fe480886dc | ||
|
|
8031d1827b | ||
|
|
b5acdb322d | ||
|
|
a4d1fe8819 | ||
|
|
10b7a58887 | ||
|
|
901a277959 | ||
|
|
aaa093bef1 | ||
|
|
bb96543d66 | ||
|
|
a2a2cfa765 | ||
|
|
18e6a2b410 | ||
|
|
db27263bc2 | ||
|
|
0e027ec3ef | ||
|
|
5acbbeecaa | ||
|
|
6ef2168b67 | ||
|
|
6d958a214c | ||
|
|
4ae4bf4ff9 | ||
|
|
fdef53b2de | ||
|
|
11bd038b9d | ||
|
|
768cfe3aab | ||
|
|
c4277b0662 | ||
|
|
020f3ccf07 | ||
|
|
7467fa5e57 | ||
|
|
e19ef7ed2f | ||
|
|
71003be6b8 | ||
|
|
c1dbafc2df | ||
|
|
dcebd71381 | ||
|
|
d855a65e73 | ||
|
|
a9007c7e0f | ||
|
|
af60304f97 | ||
|
|
6de241eead | ||
|
|
51032dc0b2 | ||
|
|
9ec3d2bc0c | ||
|
|
297931f5d9 | ||
|
|
f613c073c1 | ||
|
|
63d248622c | ||
|
|
48485fe92f | ||
|
|
07726af703 | ||
|
|
ad1004b485 | ||
|
|
0096fb2790 | ||
|
|
9c8c2e49d6 | ||
|
|
2005a96847 | ||
|
|
00a8d60c1b | ||
|
|
3aa182390a | ||
|
|
e44f1d6d4e | ||
|
|
dfdf8e2ead | ||
|
|
3a645c4e80 | ||
|
|
113129daf9 | ||
|
|
940e3b6635 | ||
|
|
7fb29dabff | ||
|
|
714ad6dbb8 | ||
|
|
c0863fa20f | ||
|
|
78b0b37ba6 | ||
|
|
5d5cdc7716 | ||
|
|
93cd818f6a | ||
|
|
598a628790 | ||
|
|
f3666eda63 | ||
|
|
754017b59e | ||
|
|
21251ce12c | ||
|
|
dc12fa6cd6 | ||
|
|
f2f4c37f19 | ||
|
|
0864fca641 | ||
|
|
5e4c0217c7 | ||
|
|
78cd106c23 | ||
|
|
6ed0efa938 | ||
|
|
ca0669c337 | ||
|
|
b59a749627 | ||
|
|
a91dee87d0 | ||
|
|
5ff98a4179 | ||
|
|
36b2f12219 | ||
|
|
5569f205ee | ||
|
|
a76cf8aab2 | ||
|
|
5c0f0d1808 | ||
|
|
951900a86a | ||
|
|
582f516fef | ||
|
|
a25bae2545 | ||
|
|
0ea35b1e3d | ||
|
|
c6f935bf1a | ||
|
|
96b4d35d43 | ||
|
|
7b0938e7e4 | ||
|
|
249522b568 | ||
|
|
39088e42cc | ||
|
|
30e0033ebe | ||
|
|
b599c40099 | ||
|
|
8f190169db | ||
|
|
1d4d705795 | ||
|
|
b3f71b3078 | ||
|
|
6059db4f15 | ||
|
|
0d5f44b153 | ||
|
|
17164a37a8 | ||
|
|
f88ccabe30 | ||
|
|
e1c85f1234 | ||
|
|
57a3eb3652 | ||
|
|
82a8972bde | ||
|
|
497a885c85 | ||
|
|
4d9f55d0f6 | ||
|
|
0c3b4bb70d | ||
|
|
33e13820fc | ||
|
|
43d991cfdb | ||
|
|
291e9cf14b | ||
|
|
a2de5c9963 | ||
|
|
5025f84627 | ||
|
|
d2c8a53c55 | ||
|
|
5659d10778 | ||
|
|
46cab81d6f | ||
|
|
dd157bce85 | ||
|
|
2f25dd7d0d | ||
|
|
e56965ad76 | ||
|
|
2273b3a8c8 | ||
|
|
05fb0ac2b2 | ||
|
|
d4acd49ee3 | ||
|
|
d98868e524 | ||
|
|
93bb27f2c7 | ||
|
|
a4c44edf8d | ||
|
|
1e94d7739a | ||
|
|
9110838fe4 | ||
|
|
ca7b267326 | ||
|
|
7f5992d6a5 | ||
|
|
88776fb2de | ||
|
|
34f567abd4 | ||
|
|
b87f3043ae | ||
|
|
3829ffbe66 | ||
|
|
ad619ae880 | ||
|
|
d22ebe08be | ||
|
|
ee0c6ad86e | ||
|
|
96adb56633 | ||
|
|
3000436121 | ||
|
|
37cdd91f5d | ||
|
|
6f3c6ddf3f | ||
|
|
0bfbda512d | ||
|
|
295b98a13c | ||
|
|
ff6b345d45 | ||
|
|
1fb307abf4 | ||
|
|
29c952dcf6 | ||
|
|
010f63a50d | ||
|
|
068bbe3a39 | ||
|
|
ad39680feb | ||
|
|
1e0ae8404c | ||
|
|
460d555a3d | ||
|
|
66ad04fcfc | ||
|
|
c7c0836721 | ||
|
|
d2c223de8f | ||
|
|
dd16f788ed | ||
|
|
b25c1af018 | ||
|
|
8f393b64b8 | ||
|
|
55b3193629 | ||
|
|
6f78c073ed | ||
|
|
c406be6f4f | ||
|
|
aeaf3737aa | ||
|
|
23d9d58c08 | ||
|
|
4c331a5d7e | ||
|
|
035425ef24 | ||
|
|
021e5a2aa3 | ||
|
|
7a1de3887e | ||
|
|
4a7a5234df | ||
|
|
6aebe1614d | ||
|
|
74292eba28 | ||
|
|
c31ff364ab | ||
|
|
f310a39381 | ||
|
|
5a7e611e0a | ||
|
|
4e29a751d8 | ||
|
|
3f94f81acd | ||
|
|
5de3c41d19 | ||
|
|
f071b03ceb | ||
|
|
b9375186a5 | ||
|
|
11bd932cba | ||
|
|
b77ccfaf32 | ||
|
|
96653eebb6 | ||
|
|
60d25f105f | ||
|
|
734b653a5f | ||
|
|
52c9e6ec91 | ||
|
|
c0f132e41a | ||
|
|
cc1160a43a | ||
|
|
adde8450bc | ||
|
|
5bf9891553 | ||
|
|
22c34c343a | ||
|
|
f7804f6126 | ||
|
|
d14b02e93f | ||
|
|
1b75d899ae | ||
|
|
d4aa79acd7 | ||
|
|
33d199c007 | ||
|
|
9c89d3452c | ||
|
|
fb0b63c580 | ||
|
|
bb2c6e5925 | ||
|
|
928caff2a6 | ||
|
|
670c79f2c7 | ||
|
|
d6efb98953 | ||
|
|
19da795274 | ||
|
|
454ba9b893 | ||
|
|
d2dc1ed26f | ||
|
|
d4fb16825e | ||
|
|
650d69ef5b | ||
|
|
ff0e79fa9a | ||
|
|
127b54f812 | ||
|
|
7025c00581 | ||
|
|
7ea995149e | ||
|
|
f9710dd6ed | ||
|
|
4e7dd7d3f6 | ||
|
|
20ca9e1fc1 | ||
|
|
8a8b09a953 | ||
|
|
9e4e386c9b | ||
|
|
eca1e449a8 | ||
|
|
ffaadb9d05 | ||
|
|
8adff96e29 | ||
|
|
7593dc19d6 | ||
|
|
b7c5a39685 | ||
|
|
bd1b84f7d0 | ||
|
|
eadfd239a8 | ||
|
|
8d75e50435 | ||
|
|
1d9c115225 | ||
|
|
30af20a056 | ||
|
|
cc21fb216c | ||
|
|
6fe62a2705 | ||
|
|
da87378713 | ||
|
|
b6f5267385 | ||
|
|
f9e78d3c64 | ||
|
|
b7b5bd1b46 | ||
|
|
9a3727d3ad | ||
|
|
d68c14516c | ||
|
|
9f4d39aa42 | ||
|
|
84b801d88f | ||
|
|
2fc70c509b | ||
|
|
34fb1c4b19 | ||
|
|
80bdd550cf | ||
|
|
7ef0d2aa35 | ||
|
|
2359b92b46 | ||
|
|
a404fb2d32 | ||
|
|
513eb11616 | ||
|
|
d2c9140e69 | ||
|
|
d95fe5925a | ||
|
|
835922ea8f | ||
|
|
e1e5266fc3 | ||
|
|
5e4457445f | ||
|
|
0221ca8f49 | ||
|
|
cf36e4029e | ||
|
|
c8a98a9a22 | ||
|
|
38ecca9362 | ||
|
|
c4681774a5 | ||
|
|
050add58d2 | ||
|
|
3d60c958c7 | ||
|
|
f5df150097 | ||
|
|
dac82adb5b | ||
|
|
b72c9787a9 | ||
|
|
2623941d91 | ||
|
|
d3a7fea939 | ||
|
|
5a7b687c84 | ||
|
|
0020457fc7 | ||
|
|
658b556544 | ||
|
|
37da0fc075 | ||
|
|
6d3e8507cc | ||
|
|
0e9470503f | ||
|
|
d2ebc6741b | ||
|
|
026d3260b4 | ||
|
|
1103ab2844 | ||
|
|
11b2076b46 | ||
|
|
78533714e3 | ||
|
|
691e1bf829 | ||
|
|
47a088d685 | ||
|
|
63db3fc22f | ||
|
|
ad0bb3f61a | ||
|
|
8f8cd90787 | ||
|
|
d796ea7bec | ||
|
|
e5b7dd63e9 | ||
|
|
af060188bd | ||
|
|
4270e7ae25 | ||
|
|
60a565d7de | ||
|
|
78cf70eaad | ||
|
|
eebaa50710 | ||
|
|
7d582553f2 | ||
|
|
4d6eea7e81 | ||
|
|
f44593331d | ||
|
|
3d9ecbf3c7 | ||
|
|
032aa1d59c | ||
|
|
35e0863bdb | ||
|
|
14070d674e | ||
|
|
108ce06c62 | ||
|
|
da364f3444 | ||
|
|
df5ba75c14 | ||
|
|
e4fb9cb33f | ||
|
|
65b527eb20 | ||
|
|
7dc9d18052 | ||
|
|
5013a4b9f3 | ||
|
|
f929359322 | ||
|
|
6522c71971 | ||
|
|
9c1e65f3a3 | ||
|
|
ebec200ba6 | ||
|
|
e559730b6e | ||
|
|
0acb8ed85d | ||
|
|
8c1c9cd702 | ||
|
|
0ece4686aa | ||
|
|
af95cef7f9 | ||
|
|
1eca7a918a | ||
|
|
9e6b958023 | ||
|
|
f7b99d93ae | ||
|
|
85d03dcd90 | ||
|
|
032555bcfe | ||
|
|
4caa1f19b2 | ||
|
|
95d4bd3012 | ||
|
|
037078c8ad | ||
|
|
6de2f66b50 | ||
|
|
cd7b248eda | ||
|
|
6d8c077f4e | ||
|
|
97127e560e | ||
|
|
27dc07d95a | ||
|
|
f7dc171c4f | ||
|
|
4b957edfec | ||
|
|
46ca7718d9 | ||
|
|
b928d7a6e6 | ||
|
|
8a836247c8 | ||
|
|
95c3644564 | ||
|
|
799cd07174 | ||
|
|
9af385468d | ||
|
|
3487388788 | ||
|
|
9a383e456d | ||
|
|
805f9f8f4a | ||
|
|
52aa0c9bbd | ||
|
|
7f5f4689cc | ||
|
|
a3f81f4b98 | ||
|
|
15c59e606f | ||
|
|
40d4cabecd | ||
|
|
3493c8119b | ||
|
|
c1e7460d39 | ||
|
|
3ffff023b2 | ||
|
|
f9384be59b | ||
|
|
6cf308004a | ||
|
|
d1029138d2 | ||
|
|
06b5800d28 | ||
|
|
483f2ccb56 | ||
|
|
93ced0bec6 | ||
|
|
4333852c37 | ||
|
|
3baa230077 | ||
|
|
9e594f9018 | ||
|
|
b0c41b4828 | ||
|
|
e0d6946b6b | ||
|
|
bf7ea8309f | ||
|
|
54b65f725f | ||
|
|
8ef49c2640 | ||
|
|
f488b1a7f2 | ||
|
|
d2edb7c402 | ||
|
|
f0a3f07b45 | ||
|
|
b42b630583 | ||
|
|
31a78d571b | ||
|
|
fdc2232ea0 | ||
|
|
e94d0b2d40 | ||
|
|
75ccbaee9c | ||
|
|
2848c8397c | ||
|
|
fe8b5193de | ||
|
|
3d1470399c | ||
|
|
fcf9c63049 | ||
|
|
7bfb5640ad | ||
|
|
15e57e3a3d | ||
|
|
279468c0e8 | ||
|
|
c565812723 | ||
|
|
ec6c8e2a38 | ||
|
|
77f2690711 | ||
|
|
c4b3a24ed7 | ||
|
|
33c69359c2 | ||
|
|
864f4bb4af | ||
|
|
5365f42a04 | ||
|
|
3dc60254b9 | ||
|
|
027a8562d7 | ||
|
|
34f3a0f0e3 | ||
|
|
d0bac1675e | ||
|
|
4e56c962f4 | ||
|
|
4ef0e43759 | ||
|
|
6945d10297 | ||
|
|
4d6cef7ac8 | ||
|
|
a7786d5ff2 | ||
|
|
6c1de975d9 | ||
|
|
a1079e455a | ||
|
|
5457c7f069 | ||
|
|
b8c1a3f96c | ||
|
|
cee8e85f76 | ||
|
|
09f166577e | ||
|
|
bcc21531fb | ||
|
|
da4eacdffe | ||
|
|
6102e560ba | ||
|
|
ff3aa57117 | ||
|
|
49db6f4fac | ||
|
|
20f6a597ab | ||
|
|
04c453721c | ||
|
|
350ffecc1f | ||
|
|
b0557aa16b | ||
|
|
1c9429a6ea | ||
|
|
206e6b1730 | ||
|
|
357cee2849 | ||
|
|
0b49997bb6 | ||
|
|
5e09dd380d | ||
|
|
c7303adb0d | ||
|
|
ed1f096a6f | ||
|
|
6ab5d28cf3 | ||
|
|
a75148cb16 | ||
|
|
f7bbc4004a | ||
|
|
cee21ca082 | ||
|
|
08ec12b391 | ||
|
|
ff5e2a9a8c | ||
|
|
e0b9b5cc6c | ||
|
|
aca4770481 | ||
|
|
5d5157fc65 | ||
|
|
fb6ef61a4d | ||
|
|
ee24ad7b13 | ||
|
|
f8e90ba3f0 | ||
|
|
ad0b70ca23 | ||
|
|
7dfa135b2c | ||
|
|
beeaa05658 | ||
|
|
6b6d654f60 | ||
|
|
853c83d0c2 | ||
|
|
1809990ed4 | ||
|
|
79d49853d2 | ||
|
|
1f608d3743 | ||
|
|
df024dd982 | ||
|
|
45da85765c | ||
|
|
bd0ad59c27 | ||
|
|
cce40acba5 | ||
|
|
bc9491ab69 | ||
|
|
f28632980d | ||
|
|
b909bac0dc | ||
|
|
8618e41b32 | ||
|
|
4687f94141 | ||
|
|
440912dcff | ||
|
|
8b87a26e7e | ||
|
|
44ae93df3e | ||
|
|
42d938fda5 | ||
|
|
8f80ba9520 | ||
|
|
25ce47c44f | ||
|
|
afd2e32092 | ||
|
|
2b213da967 | ||
|
|
e91e1eb9aa | ||
|
|
b24129fb3e | ||
|
|
350b1421bb | ||
|
|
f01c79a94f | ||
|
|
463f6352ce | ||
|
|
a80fe05e23 | ||
|
|
58d7833c5c | ||
|
|
5012f61599 | ||
|
|
85c33823c3 | ||
|
|
c83a112669 | ||
|
|
e04ada1319 | ||
|
|
d866dcb3d2 | ||
|
|
81ec476f3a | ||
|
|
1e6adf0a06 | ||
|
|
7d221e2518 | ||
|
|
742ed19d66 | ||
|
|
29c2ada23c | ||
|
|
e4196bbe5b | ||
|
|
15ffb53e59 | ||
|
|
90054ddf0d | ||
|
|
56d3cbead0 | ||
|
|
5e8c97f1ba | ||
|
|
4687ad4ed6 | ||
|
|
994b247f8e | ||
|
|
0419f50ab0 | ||
|
|
f9f40adcdc | ||
|
|
3264d30b44 | ||
|
|
4d885653e9 | ||
|
|
475b6bef53 | ||
|
|
d39de0ad38 | ||
|
|
d14a7d756e | ||
|
|
b050c1bb8f | ||
|
|
276dfc591b | ||
|
|
b49d76ebee | ||
|
|
a6be44789b | ||
|
|
a4313c26cb | ||
|
|
d4b250d509 | ||
|
|
29743a9e02 | ||
|
|
fecb77e344 | ||
|
|
779671753d | ||
|
|
d5e152b35e | ||
|
|
270657a62c | ||
|
|
3601b9c860 | ||
|
|
c8fe12cd91 | ||
|
|
deae5fbaec | ||
|
|
5b558af2b3 | ||
|
|
4150d5306f | ||
|
|
8c2e4700f9 | ||
|
|
adaecada20 | ||
|
|
258895bcc9 | ||
|
|
2eb7c25bae | ||
|
|
2e4e9434c1 | ||
|
|
0cad204e74 | ||
|
|
0bc2edc044 | ||
|
|
16488e7db8 | ||
|
|
974841926d | ||
|
|
8db20e0d95 | ||
|
|
d00d29d6b5 | ||
|
|
dc976cd665 | ||
|
|
6d6b986a66 | ||
|
|
bffdede0fa | ||
|
|
a4c258e9ec | ||
|
|
8d837558ac | ||
|
|
e673ed08ec | ||
|
|
f0e07bff5a | ||
|
|
3ec06a1fc3 | ||
|
|
6b79e2b407 | ||
|
|
0eed9dbc44 | ||
|
|
53c7832fd1 | ||
|
|
ca1cc0e2c2 | ||
|
|
5d8728c7ef | ||
|
|
a8cec4c7e6 | ||
|
|
2b5ccdc55f | ||
|
|
d92d5b5258 | ||
|
|
a591184d2a | ||
|
|
ee881e4c78 | ||
|
|
61fbb24e36 | ||
|
|
d582949488 | ||
|
|
de574eb4d9 | ||
|
|
bfd90968f1 | ||
|
|
4a924c9b54 | ||
|
|
0453d60c64 | ||
|
|
c4f4f8b1b8 | ||
|
|
3e80eaa342 | ||
|
|
00a0cb3403 | ||
|
|
ea93cad5ff | ||
|
|
4453a0d20d | ||
|
|
1e837e3c9d | ||
|
|
0f95f7cea3 | ||
|
|
0b0068ab86 | ||
|
|
31c7fa833e | ||
|
|
db16ca0079 | ||
|
|
a824f47bc6 | ||
|
|
99392debe8 | ||
|
|
0cc739afc8 | ||
|
|
0ab62b0343 | ||
|
|
75d25dd5cc | ||
|
|
2e54da13d8 | ||
|
|
f34f416bf5 | ||
|
|
021c63891d | ||
|
|
a968862e6b | ||
|
|
a08189d457 | ||
|
|
0a936696c3 | ||
|
|
55e33eaf4c | ||
|
|
3da5fb223f | ||
|
|
a3c5a664e5 | ||
|
|
b638fb2f30 | ||
|
|
c1b10b2222 | ||
|
|
bee29714d9 | ||
|
|
d40d5276dd | ||
|
|
568f0aad71 | ||
|
|
38474fa9d4 | ||
|
|
f7f974a28b | ||
|
|
3c150b384c | ||
|
|
65816049ba | ||
|
|
c1c881ded5 | ||
|
|
82c4dd8b86 | ||
|
|
711d09a107 | ||
|
|
74013b6611 | ||
|
|
790f399986 | ||
|
|
73cdd36594 | ||
|
|
50ac3eb28d | ||
|
|
d753cff91a | ||
|
|
89f1909e4b | ||
|
|
37916a22ad | ||
|
|
8cb2fa8600 | ||
|
|
8f460b92f1 | ||
|
|
d99a08a441 | ||
|
|
b164330e3c | ||
|
|
0b0e6fe448 | ||
|
|
c132dbdefa | ||
|
|
f3081e7013 | ||
|
|
f904f14f9e | ||
|
|
8917a6d99b | ||
|
|
5a4765046e |
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@@ -2,7 +2,7 @@
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @tildebyte @blessedcoolant
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
@@ -18,17 +18,17 @@
|
||||
/invokeai/version @lstein @blessedcoolant
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant
|
||||
/invokeai/frontend/training @lstein @blessedcoolant
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
|
||||
|
||||
|
||||
|
||||
15
.github/workflows/mkdocs-material.yml
vendored
15
.github/workflows/mkdocs-material.yml
vendored
@@ -2,8 +2,7 @@ name: mkdocs-material
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
- 'refs/heads/v2.3'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -12,6 +11,10 @@ jobs:
|
||||
mkdocs-material:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
@@ -22,11 +25,15 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install -r docs/requirements-mkdocs.txt
|
||||
pip install ".[docs]"
|
||||
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
@@ -36,7 +43,7 @@ jobs:
|
||||
--verbose
|
||||
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
if: ${{ github.ref == 'refs/heads/v2.3' }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
|
||||
19
.github/workflows/test-invoke-pip.yml
vendored
19
.github/workflows/test-invoke-pip.yml
vendored
@@ -80,11 +80,6 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: set test prompt to Pull Request validation
|
||||
if: ${{ github.ref != 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
@@ -105,12 +100,6 @@ jobs:
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
- name: set INVOKEAI_OUTDIR
|
||||
run: >
|
||||
python -c
|
||||
"import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
|
||||
>> ${{ matrix.github-env }}
|
||||
|
||||
- name: run invokeai-configure
|
||||
id: run-preload-models
|
||||
env:
|
||||
@@ -129,15 +118,21 @@ jobs:
|
||||
HF_HUB_OFFLINE: 1
|
||||
HF_DATASETS_OFFLINE: 1
|
||||
TRANSFORMERS_OFFLINE: 1
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
run: >
|
||||
invokeai
|
||||
--no-patchmatch
|
||||
--no-nsfw_checker
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
--precision=float32
|
||||
--always_use_cpu
|
||||
--use_memory_db
|
||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
|
||||
- name: Archive results
|
||||
id: archive-results
|
||||
env:
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -201,6 +201,8 @@ checkpoints
|
||||
# If it's a Mac
|
||||
.DS_Store
|
||||
|
||||
invokeai/frontend/web/dist/*
|
||||
|
||||
# Let the frontend manage its own gitignore
|
||||
!invokeai/frontend/web/*
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@
|
||||
|
||||
</div>
|
||||
|
||||
_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_
|
||||
|
||||
InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.
|
||||
|
||||
**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
|
||||
|
||||
Binary file not shown.
@@ -1,164 +0,0 @@
|
||||
@echo off
|
||||
|
||||
@rem This script will install git (if not found on the PATH variable)
|
||||
@rem using micromamba (an 8mb static-linked single-file binary, conda replacement).
|
||||
@rem For users who already have git, this step will be skipped.
|
||||
|
||||
@rem Next, it'll download the project's source code.
|
||||
@rem Then it will download a self-contained, standalone Python and unpack it.
|
||||
@rem Finally, it'll create the Python virtual environment and preload the models.
|
||||
|
||||
@rem This enables a user to install this project without manually installing git or Python
|
||||
|
||||
@rem change to the script's directory
|
||||
PUSHD "%~dp0"
|
||||
|
||||
set "no_cache_dir=--no-cache-dir"
|
||||
if "%1" == "use-cache" (
|
||||
set "no_cache_dir="
|
||||
)
|
||||
|
||||
echo ***** Installing InvokeAI.. *****
|
||||
@rem Config
|
||||
set INSTALL_ENV_DIR=%cd%\installer_files\env
|
||||
@rem https://mamba.readthedocs.io/en/latest/installation.html
|
||||
set MICROMAMBA_DOWNLOAD_URL=https://github.com/cmdr2/stable-diffusion-ui/releases/download/v1.1/micromamba.exe
|
||||
set RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||
set RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||
set PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||
set PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-x86_64-pc-windows-msvc-shared-install_only.tar.gz
|
||||
|
||||
set PACKAGES_TO_INSTALL=
|
||||
|
||||
call git --version >.tmp1 2>.tmp2
|
||||
if "%ERRORLEVEL%" NEQ "0" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% git
|
||||
|
||||
@rem Cleanup
|
||||
del /q .tmp1 .tmp2
|
||||
|
||||
@rem (if necessary) install git into a contained environment
|
||||
if "%PACKAGES_TO_INSTALL%" NEQ "" (
|
||||
@rem download micromamba
|
||||
echo ***** Downloading micromamba from %MICROMAMBA_DOWNLOAD_URL% to micromamba.exe *****
|
||||
|
||||
call curl -L "%MICROMAMBA_DOWNLOAD_URL%" > micromamba.exe
|
||||
|
||||
@rem test the mamba binary
|
||||
echo ***** Micromamba version: *****
|
||||
call micromamba.exe --version
|
||||
|
||||
@rem create the installer env
|
||||
if not exist "%INSTALL_ENV_DIR%" (
|
||||
call micromamba.exe create -y --prefix "%INSTALL_ENV_DIR%"
|
||||
)
|
||||
|
||||
echo ***** Packages to install:%PACKAGES_TO_INSTALL% *****
|
||||
|
||||
call micromamba.exe install -y --prefix "%INSTALL_ENV_DIR%" -c conda-forge %PACKAGES_TO_INSTALL%
|
||||
|
||||
if not exist "%INSTALL_ENV_DIR%" (
|
||||
echo ----- There was a problem while installing "%PACKAGES_TO_INSTALL%" using micromamba. Cannot continue. -----
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
)
|
||||
|
||||
del /q micromamba.exe
|
||||
|
||||
@rem For 'git' only
|
||||
set PATH=%INSTALL_ENV_DIR%\Library\bin;%PATH%
|
||||
|
||||
@rem Download/unpack/clean up InvokeAI release sourceball
|
||||
set err_msg=----- InvokeAI source download failed -----
|
||||
echo Trying to download "%RELEASE_URL%%RELEASE_SOURCEBALL%"
|
||||
curl -L %RELEASE_URL%%RELEASE_SOURCEBALL% --output InvokeAI.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- InvokeAI source unpack failed -----
|
||||
tar -zxf InvokeAI.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
del /q InvokeAI.tgz
|
||||
|
||||
set err_msg=----- InvokeAI source copy failed -----
|
||||
cd InvokeAI-*
|
||||
xcopy . .. /e /h
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
cd ..
|
||||
|
||||
@rem cleanup
|
||||
for /f %%i in ('dir /b InvokeAI-*') do rd /s /q %%i
|
||||
rd /s /q .dev_scripts .github docker-build tests
|
||||
del /q requirements.in requirements-mkdocs.txt shell.nix
|
||||
|
||||
echo ***** Unpacked InvokeAI source *****
|
||||
|
||||
@rem Download/unpack/clean up python-build-standalone
|
||||
set err_msg=----- Python download failed -----
|
||||
curl -L %PYTHON_BUILD_STANDALONE_URL%/%PYTHON_BUILD_STANDALONE% --output python.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- Python unpack failed -----
|
||||
tar -zxf python.tgz
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
del /q python.tgz
|
||||
|
||||
echo ***** Unpacked python-build-standalone *****
|
||||
|
||||
@rem create venv
|
||||
set err_msg=----- problem creating venv -----
|
||||
.\python\python -E -s -m venv .venv
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
call .venv\Scripts\activate.bat
|
||||
|
||||
echo ***** Created Python virtual environment *****
|
||||
|
||||
@rem Print venv's Python version
|
||||
set err_msg=----- problem calling venv's python -----
|
||||
echo We're running under
|
||||
.venv\Scripts\python --version
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- pip update failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location --upgrade pip wheel
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
echo ***** Updated pip and wheel *****
|
||||
|
||||
set err_msg=----- requirements file copy failed -----
|
||||
copy binary_installer\py3.10-windows-x86_64-cuda-reqs.txt requirements.txt
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
set err_msg=----- main pip install failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -r requirements.txt
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
echo ***** Installed Python dependencies *****
|
||||
|
||||
set err_msg=----- InvokeAI setup failed -----
|
||||
.venv\Scripts\python -m pip install %no_cache_dir% --no-warn-script-location -e .
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
|
||||
copy binary_installer\invoke.bat.in .\invoke.bat
|
||||
echo ***** Installed invoke launcher script ******
|
||||
|
||||
@rem more cleanup
|
||||
rd /s /q binary_installer installer_files
|
||||
|
||||
@rem preload the models
|
||||
call .venv\Scripts\python ldm\invoke\config\invokeai_configure.py
|
||||
set err_msg=----- model download clone failed -----
|
||||
if %errorlevel% neq 0 goto err_exit
|
||||
deactivate
|
||||
|
||||
echo ***** Finished downloading models *****
|
||||
|
||||
echo All done! Execute the file invoke.bat in this directory to start InvokeAI
|
||||
pause
|
||||
exit
|
||||
|
||||
:err_exit
|
||||
echo %err_msg%
|
||||
pause
|
||||
exit
|
||||
@@ -1,235 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
set -euo pipefail
|
||||
IFS=$'\n\t'
|
||||
|
||||
function _err_exit {
|
||||
if test "$1" -ne 0
|
||||
then
|
||||
echo -e "Error code $1; Error caught was '$2'"
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
fi
|
||||
}
|
||||
|
||||
# This script will install git (if not found on the PATH variable)
|
||||
# using micromamba (an 8mb static-linked single-file binary, conda replacement).
|
||||
# For users who already have git, this step will be skipped.
|
||||
|
||||
# Next, it'll download the project's source code.
|
||||
# Then it will download a self-contained, standalone Python and unpack it.
|
||||
# Finally, it'll create the Python virtual environment and preload the models.
|
||||
|
||||
# This enables a user to install this project without manually installing git or Python
|
||||
|
||||
echo -e "\n***** Installing InvokeAI into $(pwd)... *****\n"
|
||||
|
||||
export no_cache_dir="--no-cache-dir"
|
||||
if [ $# -ge 1 ]; then
|
||||
if [ "$1" = "use-cache" ]; then
|
||||
export no_cache_dir=""
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
OS_NAME=$(uname -s)
|
||||
case "${OS_NAME}" in
|
||||
Linux*) OS_NAME="linux";;
|
||||
Darwin*) OS_NAME="darwin";;
|
||||
*) echo -e "\n----- Unknown OS: $OS_NAME! This script runs only on Linux or macOS -----\n" && exit
|
||||
esac
|
||||
|
||||
OS_ARCH=$(uname -m)
|
||||
case "${OS_ARCH}" in
|
||||
x86_64*) ;;
|
||||
arm64*) ;;
|
||||
*) echo -e "\n----- Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64 -----\n" && exit
|
||||
esac
|
||||
|
||||
# https://mamba.readthedocs.io/en/latest/installation.html
|
||||
MAMBA_OS_NAME=$OS_NAME
|
||||
MAMBA_ARCH=$OS_ARCH
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
MAMBA_OS_NAME="osx"
|
||||
fi
|
||||
|
||||
if [ "$OS_ARCH" == "linux" ]; then
|
||||
MAMBA_ARCH="aarch64"
|
||||
fi
|
||||
|
||||
if [ "$OS_ARCH" == "x86_64" ]; then
|
||||
MAMBA_ARCH="64"
|
||||
fi
|
||||
|
||||
PY_ARCH=$OS_ARCH
|
||||
if [ "$OS_ARCH" == "arm64" ]; then
|
||||
PY_ARCH="aarch64"
|
||||
fi
|
||||
|
||||
# Compute device ('cd' segment of reqs files) detect goes here
|
||||
# This needs a ton of work
|
||||
# Suggestions:
|
||||
# - lspci
|
||||
# - check $PATH for nvidia-smi, gtt CUDA/GPU version from output
|
||||
# - Surely there's a similar utility for AMD?
|
||||
CD="cuda"
|
||||
if [ "$OS_NAME" == "darwin" ] && [ "$OS_ARCH" == "arm64" ]; then
|
||||
CD="mps"
|
||||
fi
|
||||
|
||||
# config
|
||||
INSTALL_ENV_DIR="$(pwd)/installer_files/env"
|
||||
MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/${MAMBA_OS_NAME}-${MAMBA_ARCH}/latest"
|
||||
RELEASE_URL=https://github.com/invoke-ai/InvokeAI
|
||||
RELEASE_SOURCEBALL=/archive/refs/heads/main.tar.gz
|
||||
PYTHON_BUILD_STANDALONE_URL=https://github.com/indygreg/python-build-standalone/releases/download
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-apple-darwin-install_only.tar.gz
|
||||
elif [ "$OS_NAME" == "linux" ]; then
|
||||
PYTHON_BUILD_STANDALONE=20221002/cpython-3.10.7+20221002-${PY_ARCH}-unknown-linux-gnu-install_only.tar.gz
|
||||
fi
|
||||
echo "INSTALLING $RELEASE_SOURCEBALL FROM $RELEASE_URL"
|
||||
|
||||
PACKAGES_TO_INSTALL=""
|
||||
|
||||
if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
|
||||
|
||||
# (if necessary) install git and conda into a contained environment
|
||||
if [ "$PACKAGES_TO_INSTALL" != "" ]; then
|
||||
# download micromamba
|
||||
echo -e "\n***** Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to micromamba *****\n"
|
||||
|
||||
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > micromamba
|
||||
|
||||
chmod u+x ./micromamba
|
||||
|
||||
# test the mamba binary
|
||||
echo -e "\n***** Micromamba version: *****\n"
|
||||
./micromamba --version
|
||||
|
||||
# create the installer env
|
||||
if [ ! -e "$INSTALL_ENV_DIR" ]; then
|
||||
./micromamba create -y --prefix "$INSTALL_ENV_DIR"
|
||||
fi
|
||||
|
||||
echo -e "\n***** Packages to install:$PACKAGES_TO_INSTALL *****\n"
|
||||
|
||||
./micromamba install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge "$PACKAGES_TO_INSTALL"
|
||||
|
||||
if [ ! -e "$INSTALL_ENV_DIR" ]; then
|
||||
echo -e "\n----- There was a problem while initializing micromamba. Cannot continue. -----\n"
|
||||
exit
|
||||
fi
|
||||
fi
|
||||
|
||||
rm -f micromamba.exe
|
||||
|
||||
export PATH="$INSTALL_ENV_DIR/bin:$PATH"
|
||||
|
||||
# Download/unpack/clean up InvokeAI release sourceball
|
||||
_err_msg="\n----- InvokeAI source download failed -----\n"
|
||||
curl -L $RELEASE_URL/$RELEASE_SOURCEBALL --output InvokeAI.tgz
|
||||
_err_exit $? _err_msg
|
||||
_err_msg="\n----- InvokeAI source unpack failed -----\n"
|
||||
tar -zxf InvokeAI.tgz
|
||||
_err_exit $? _err_msg
|
||||
|
||||
rm -f InvokeAI.tgz
|
||||
|
||||
_err_msg="\n----- InvokeAI source copy failed -----\n"
|
||||
cd InvokeAI-*
|
||||
cp -r . ..
|
||||
_err_exit $? _err_msg
|
||||
cd ..
|
||||
|
||||
# cleanup
|
||||
rm -rf InvokeAI-*/
|
||||
rm -rf .dev_scripts/ .github/ docker-build/ tests/ requirements.in requirements-mkdocs.txt shell.nix
|
||||
|
||||
echo -e "\n***** Unpacked InvokeAI source *****\n"
|
||||
|
||||
# Download/unpack/clean up python-build-standalone
|
||||
_err_msg="\n----- Python download failed -----\n"
|
||||
curl -L $PYTHON_BUILD_STANDALONE_URL/$PYTHON_BUILD_STANDALONE --output python.tgz
|
||||
_err_exit $? _err_msg
|
||||
_err_msg="\n----- Python unpack failed -----\n"
|
||||
tar -zxf python.tgz
|
||||
_err_exit $? _err_msg
|
||||
|
||||
rm -f python.tgz
|
||||
|
||||
echo -e "\n***** Unpacked python-build-standalone *****\n"
|
||||
|
||||
# create venv
|
||||
_err_msg="\n----- problem creating venv -----\n"
|
||||
|
||||
if [ "$OS_NAME" == "darwin" ]; then
|
||||
# patch sysconfig so that extensions can build properly
|
||||
# adapted from https://github.com/cashapp/hermit-packages/commit/fcba384663892f4d9cfb35e8639ff7a28166ee43
|
||||
PYTHON_INSTALL_DIR="$(pwd)/python"
|
||||
SYSCONFIG="$(echo python/lib/python*/_sysconfigdata_*.py)"
|
||||
TMPFILE="$(mktemp)"
|
||||
chmod +w "${SYSCONFIG}"
|
||||
cp "${SYSCONFIG}" "${TMPFILE}"
|
||||
sed "s,'/install,'${PYTHON_INSTALL_DIR},g" "${TMPFILE}" > "${SYSCONFIG}"
|
||||
rm -f "${TMPFILE}"
|
||||
fi
|
||||
|
||||
./python/bin/python3 -E -s -m venv .venv
|
||||
_err_exit $? _err_msg
|
||||
source .venv/bin/activate
|
||||
|
||||
echo -e "\n***** Created Python virtual environment *****\n"
|
||||
|
||||
# Print venv's Python version
|
||||
_err_msg="\n----- problem calling venv's python -----\n"
|
||||
echo -e "We're running under"
|
||||
.venv/bin/python3 --version
|
||||
_err_exit $? _err_msg
|
||||
|
||||
_err_msg="\n----- pip update failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location --upgrade pip
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Updated pip *****\n"
|
||||
|
||||
_err_msg="\n----- requirements file copy failed -----\n"
|
||||
cp binary_installer/py3.10-${OS_NAME}-"${OS_ARCH}"-${CD}-reqs.txt requirements.txt
|
||||
_err_exit $? _err_msg
|
||||
|
||||
_err_msg="\n----- main pip install failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -r requirements.txt
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Installed Python dependencies *****\n"
|
||||
|
||||
_err_msg="\n----- InvokeAI setup failed -----\n"
|
||||
.venv/bin/python3 -m pip install $no_cache_dir --no-warn-script-location -e .
|
||||
_err_exit $? _err_msg
|
||||
|
||||
echo -e "\n***** Installed InvokeAI *****\n"
|
||||
|
||||
cp binary_installer/invoke.sh.in ./invoke.sh
|
||||
chmod a+rx ./invoke.sh
|
||||
echo -e "\n***** Installed invoke launcher script ******\n"
|
||||
|
||||
# more cleanup
|
||||
rm -rf binary_installer/ installer_files/
|
||||
|
||||
# preload the models
|
||||
.venv/bin/python3 scripts/configure_invokeai.py
|
||||
_err_msg="\n----- model download clone failed -----\n"
|
||||
_err_exit $? _err_msg
|
||||
deactivate
|
||||
|
||||
echo -e "\n***** Finished downloading models *****\n"
|
||||
|
||||
echo "All done! Run the command"
|
||||
echo " $scriptdir/invoke.sh"
|
||||
echo "to start InvokeAI."
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
@@ -1,36 +0,0 @@
|
||||
@echo off
|
||||
|
||||
PUSHD "%~dp0"
|
||||
call .venv\Scripts\activate.bat
|
||||
|
||||
echo Do you want to generate images using the
|
||||
echo 1. command-line
|
||||
echo 2. browser-based UI
|
||||
echo OR
|
||||
echo 3. open the developer console
|
||||
set /p choice="Please enter 1, 2 or 3: "
|
||||
if /i "%choice%" == "1" (
|
||||
echo Starting the InvokeAI command-line.
|
||||
.venv\Scripts\python scripts\invoke.py %*
|
||||
) else if /i "%choice%" == "2" (
|
||||
echo Starting the InvokeAI browser-based UI.
|
||||
.venv\Scripts\python scripts\invoke.py --web %*
|
||||
) else if /i "%choice%" == "3" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
echo Python version is:
|
||||
python --version
|
||||
echo *************************
|
||||
echo You are now in the system shell, with the local InvokeAI Python virtual environment activated,
|
||||
echo so that you can troubleshoot this InvokeAI installation as necessary.
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) else (
|
||||
echo Invalid selection
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
deactivate
|
||||
@@ -1,46 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
set -eu
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
# set required env var for torch on mac MPS
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line"
|
||||
echo "2. browser-based UI"
|
||||
echo "OR"
|
||||
echo "3. open the developer console"
|
||||
echo "Please enter 1, 2, or 3:"
|
||||
read choice
|
||||
|
||||
case $choice in
|
||||
1)
|
||||
printf "\nStarting the InvokeAI command-line..\n";
|
||||
.venv/bin/python scripts/invoke.py $*;
|
||||
;;
|
||||
2)
|
||||
printf "\nStarting the InvokeAI browser-based UI..\n";
|
||||
.venv/bin/python scripts/invoke.py --web $*;
|
||||
;;
|
||||
3)
|
||||
printf "\nDeveloper Console:\n";
|
||||
printf "Python command is:\n\t";
|
||||
which python;
|
||||
printf "Python version is:\n\t";
|
||||
python --version;
|
||||
echo "*************************"
|
||||
echo "You are now in your user shell ($SHELL) with the local InvokeAI Python virtual environment activated,";
|
||||
echo "so that you can troubleshoot this InvokeAI installation as necessary.";
|
||||
printf "*************************\n"
|
||||
echo "*** Type \`exit\` to quit this shell and deactivate the Python virtual environment *** ";
|
||||
/usr/bin/env "$SHELL";
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection";
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,17 +0,0 @@
|
||||
InvokeAI
|
||||
|
||||
Project homepage: https://github.com/invoke-ai/InvokeAI
|
||||
|
||||
Installation on Windows:
|
||||
NOTE: You might need to enable Windows Long Paths. If you're not sure,
|
||||
then you almost certainly need to. Simply double-click the 'WinLongPathsEnabled.reg'
|
||||
file. Note that you will need to have admin privileges in order to
|
||||
do this.
|
||||
|
||||
Please double-click the 'install.bat' file (while keeping it inside the invokeAI folder).
|
||||
|
||||
Installation on Linux and Mac:
|
||||
Please open the terminal, and run './install.sh' (while keeping it inside the invokeAI folder).
|
||||
|
||||
After installation, please run the 'invoke.bat' file (on Windows) or 'invoke.sh'
|
||||
file (on Linux/Mac) to start InvokeAI.
|
||||
@@ -1,33 +0,0 @@
|
||||
--prefer-binary
|
||||
--extra-index-url https://download.pytorch.org/whl/torch_stable.html
|
||||
--extra-index-url https://download.pytorch.org/whl/cu116
|
||||
--trusted-host https://download.pytorch.org
|
||||
accelerate~=0.15
|
||||
albumentations
|
||||
diffusers[torch]~=0.11
|
||||
einops
|
||||
eventlet
|
||||
flask_cors
|
||||
flask_socketio
|
||||
flaskwebgui==1.0.3
|
||||
getpass_asterisk
|
||||
imageio-ffmpeg
|
||||
pyreadline3
|
||||
realesrgan
|
||||
send2trash
|
||||
streamlit
|
||||
taming-transformers-rom1504
|
||||
test-tube
|
||||
torch-fidelity
|
||||
torch==1.12.1 ; platform_system == 'Darwin'
|
||||
torch==1.12.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
|
||||
torchvision==0.13.1 ; platform_system == 'Darwin'
|
||||
torchvision==0.13.0+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
|
||||
transformers
|
||||
picklescan
|
||||
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
|
||||
https://github.com/invoke-ai/clipseg/archive/1f754751c85d7d4255fa681f4491ff5711c1c288.zip
|
||||
https://github.com/invoke-ai/GFPGAN/archive/3f5d2397361199bc4a91c08bb7d80f04d7805615.zip ; platform_system=='Windows'
|
||||
https://github.com/invoke-ai/GFPGAN/archive/c796277a1cf77954e5fc0b288d7062d162894248.zip ; platform_system=='Linux' or platform_system=='Darwin'
|
||||
https://github.com/Birch-san/k-diffusion/archive/363386981fee88620709cf8f6f2eea167bd6cd74.zip
|
||||
https://github.com/invoke-ai/PyPatchMatch/archive/129863937a8ab37f6bbcec327c994c0f932abdbc.zip
|
||||
@@ -19,31 +19,56 @@ An invocation looks like this:
|
||||
```py
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
|
||||
# fmt: off
|
||||
type: Literal["upscale"] = "upscale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description = "The upscale level")
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
strength=0.0, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
Each portion is important to implement correctly.
|
||||
@@ -95,25 +120,67 @@ Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
If the `name` also matches, then the field can be **automatically linked** to a
|
||||
previous invocation by name and matching.
|
||||
|
||||
### Config
|
||||
|
||||
```py
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
This is an optional configuration for the invocation. It inherits from
|
||||
pydantic's model `Config` class, and it used primarily to customize the
|
||||
autogenerated OpenAPI schema.
|
||||
|
||||
The UI relies on the OpenAPI schema in two ways:
|
||||
|
||||
- An API client & Typescript types are generated from it. This happens at build
|
||||
time.
|
||||
- The node editor parses the schema into a template used by the UI to create the
|
||||
node editor UI. This parsing happens at runtime.
|
||||
|
||||
In this example, a `ui` key has been added to the `schema_extra` dict to provide
|
||||
some tags for the UI, to facilitate filtering nodes.
|
||||
|
||||
See the Schema Generation section below for more information.
|
||||
|
||||
### Invoke Function
|
||||
|
||||
```py
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
results = context.services.generate.upscale_and_reconstruct(
|
||||
image_list = [[image, 0]],
|
||||
upscale = (self.level, self.strength),
|
||||
strength = 0.0, # GFPGAN strength
|
||||
save_original = False,
|
||||
image_callback = None,
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
strength=0.0, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(context.graph_execution_state_id, self.id)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
```
|
||||
|
||||
@@ -135,9 +202,16 @@ scenarios. If you need functionality, please provide it as a service in the
|
||||
```py
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
type: Literal['image'] = 'image'
|
||||
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
```
|
||||
|
||||
Output classes look like an invocation class without the invoke method. Prefer
|
||||
@@ -168,35 +242,36 @@ Here's that `ImageOutput` class, without the needed schema customisation:
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
```
|
||||
|
||||
The generated OpenAPI schema, and all clients/types generated from it, will have
|
||||
the `type` and `image` properties marked as optional, even though we know they
|
||||
will always have a value by the time we can interact with them via the API.
|
||||
|
||||
Here's the same class, but with the schema customisation added:
|
||||
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
|
||||
`image`, `width` and `height` properties marked as optional, even though we know
|
||||
they will always have a value.
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
# Add schema customization
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
}
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
```
|
||||
|
||||
The resultant schema (and any API client or types generated from it) will now
|
||||
have see `type` as string literal `"image"` and `image` as an `ImageField`
|
||||
object.
|
||||
With the customization in place, the schema will now show these properties as
|
||||
required, obviating the need for extensive null checks in client code.
|
||||
|
||||
See this `pydantic` issue for discussion on this solution:
|
||||
<https://github.com/pydantic/pydantic/discussions/4577>
|
||||
|
||||
171
docs/features/LOGGING.md
Normal file
171
docs/features/LOGGING.md
Normal file
@@ -0,0 +1,171 @@
|
||||
---
|
||||
title: Controlling Logging
|
||||
---
|
||||
|
||||
# :material-image-off: Controlling Logging
|
||||
|
||||
## Controlling How InvokeAI Logs Status Messages
|
||||
|
||||
InvokeAI logs status messages using a configurable logging system. You
|
||||
can log to the terminal window, to a designated file on the local
|
||||
machine, to the syslog facility on a Linux or Mac, or to a properly
|
||||
configured web server. You can configure several logs at the same
|
||||
time, and control the level of message logged and the logging format
|
||||
(to a limited extent).
|
||||
|
||||
Three command-line options control logging:
|
||||
|
||||
### `--log_handlers <handler1> <handler2> ...`
|
||||
|
||||
This option activates one or more log handlers. Options are "console",
|
||||
"file", "syslog" and "http". To specify more than one, separate them
|
||||
by spaces:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
|
||||
```
|
||||
|
||||
The format of these options is described below.
|
||||
|
||||
### `--log_format {plain|color|legacy|syslog}`
|
||||
|
||||
This controls the format of log messages written to the console. Only
|
||||
the "console" log handler is currently affected by this setting.
|
||||
|
||||
* "plain" provides formatted messages like this:
|
||||
|
||||
```bash
|
||||
|
||||
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
|
||||
```
|
||||
|
||||
* "color" produces similar output, but the text will be color coded to
|
||||
indicate the severity of the message.
|
||||
|
||||
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
|
||||
|
||||
```bash
|
||||
### this is a critical error
|
||||
*** this is an error
|
||||
** this is a warning
|
||||
>> this is an informational messages
|
||||
| this is a debug message
|
||||
```
|
||||
|
||||
* "syslog" produces messages suitable for syslog entries:
|
||||
|
||||
```bash
|
||||
InvokeAI [2691178] <CRITICAL> this is a critical error
|
||||
InvokeAI [2691178] <ERROR> this is an error
|
||||
InvokeAI [2691178] <WARNING> this is a warning
|
||||
InvokeAI [2691178] <INFO> this is an informational messages
|
||||
InvokeAI [2691178] <DEBUG> this is a debug message
|
||||
```
|
||||
|
||||
(note that the date, time and hostname will be added by the syslog
|
||||
system)
|
||||
|
||||
### `--log_level {debug|info|warning|error|critical}`
|
||||
|
||||
Providing this command-line option will cause only messages at the
|
||||
specified level or above to be emitted.
|
||||
|
||||
## Console logging
|
||||
|
||||
When "console" is provided to `--log_handlers`, messages will be
|
||||
written to the command line window in which InvokeAI was launched. By
|
||||
default, the color formatter will be used unless overridden by
|
||||
`--log_format`.
|
||||
|
||||
## File logging
|
||||
|
||||
When "file" is provided to `--log_handlers`, entries will be written
|
||||
to the file indicated in the path argument. By default, the "plain"
|
||||
format will be used:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers file=/var/log/invokeai.log
|
||||
```
|
||||
|
||||
## Syslog logging
|
||||
|
||||
When "syslog" is requested, entries will be sent to the syslog
|
||||
system. There are a variety of ways to control where the log message
|
||||
is sent:
|
||||
|
||||
* Send to the local machine using the `/dev/log` socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=/dev/log
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message on a nonstandard
|
||||
port:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost:512
|
||||
```
|
||||
|
||||
* Send to a remote machine named "loghost" on the local LAN using
|
||||
facility LOG_USER and UDP packets:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
|
||||
```
|
||||
|
||||
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
|
||||
are defaults.
|
||||
|
||||
* Send to a remote machine named "loghost" using the facility LOCAL0
|
||||
and using a TCP socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
|
||||
```
|
||||
|
||||
If no arguments are specified (just a bare "syslog"), then the logging
|
||||
system will look for a UNIX socket named `/dev/log`, and if not found
|
||||
try to send a UDP message to `localhost`. The Macintosh OS used to
|
||||
support logging to a socket named `/var/run/syslog`, but this feature
|
||||
has since been disabled.
|
||||
|
||||
## Web logging
|
||||
|
||||
If you have access to a web server that is configured to log messages
|
||||
when a particular URL is requested, you can log using the "http"
|
||||
method:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
|
||||
```
|
||||
|
||||
The optional [,method=] part can be used to specify whether the URL
|
||||
accepts GET (default) or POST messages.
|
||||
|
||||
Currently password authentication and SSL are not supported.
|
||||
|
||||
## Using the configuration file
|
||||
|
||||
You can set and forget logging options by adding a "Logging" section
|
||||
to `invokeai.yaml`:
|
||||
|
||||
```
|
||||
InvokeAI:
|
||||
[... other settings...]
|
||||
Logging:
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=/dev/log
|
||||
log_level: info
|
||||
log_format: color
|
||||
```
|
||||
@@ -57,6 +57,9 @@ Personalize models by adding your own style or subjects.
|
||||
## * [The NSFW Checker](NSFW.md)
|
||||
Prevent InvokeAI from displaying unwanted racy images.
|
||||
|
||||
## * [Controlling Logging](LOGGING.md)
|
||||
Control how InvokeAI logs status messages.
|
||||
|
||||
## * [Miscellaneous](OTHER.md)
|
||||
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
||||
batch process a file of prompts, increase the "creativity" of image
|
||||
|
||||
@@ -89,7 +89,7 @@ experimental versions later.
|
||||
sudo apt update
|
||||
sudo apt install -y software-properties-common
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt install python3.10 python3-pip python3.10-venv
|
||||
sudo apt install -y python3.10 python3-pip python3.10-venv
|
||||
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
|
||||
```
|
||||
|
||||
|
||||
@@ -216,7 +216,7 @@ manager, please follow these steps:
|
||||
9. Run the command-line- or the web- interface:
|
||||
|
||||
From within INVOKEAI_ROOT, activate the environment
|
||||
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
|
||||
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
|
||||
the script `invokeai`. If the virtual environment you selected is NOT inside
|
||||
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
|
||||
`--root_dir \path\to\invokeai` to the commands below:
|
||||
|
||||
@@ -247,8 +247,8 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torch~=2.0.0",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
|
||||
@@ -7,42 +7,42 @@ call .venv\Scripts\activate.bat
|
||||
set INVOKEAI_ROOT=.
|
||||
|
||||
:start
|
||||
echo Do you want to generate images using the
|
||||
echo 1. command-line interface
|
||||
echo 2. browser-based UI
|
||||
echo 3. run textual inversion training
|
||||
echo 4. merge models (diffusers type only)
|
||||
echo 5. download and install models
|
||||
echo 6. change InvokeAI startup options
|
||||
echo 7. re-run the configure script to fix a broken install
|
||||
echo 8. open the developer console
|
||||
echo 9. update InvokeAI
|
||||
echo 10. command-line help
|
||||
echo Q - quit
|
||||
set /P restore="Please enter 1-10, Q: [2] "
|
||||
if not defined restore set restore=2
|
||||
IF /I "%restore%" == "1" (
|
||||
echo Desired action:
|
||||
echo 1. Generate images with the browser-based interface
|
||||
echo 2. Explore InvokeAI nodes using a command-line interface
|
||||
echo 3. Run textual inversion training
|
||||
echo 4. Merge models (diffusers type only)
|
||||
echo 5. Download and install models
|
||||
echo 6. Change InvokeAI startup options
|
||||
echo 7. Re-run the configure script to fix a broken install
|
||||
echo 8. Open the developer console
|
||||
echo 9. Update InvokeAI
|
||||
echo 10. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [2] "
|
||||
if not defined choice set choice=2
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai-web.exe %*
|
||||
) ELSE IF /I "%choice%" == "2" (
|
||||
echo Starting the InvokeAI command-line..
|
||||
python .venv\Scripts\invokeai.exe %*
|
||||
) ELSE IF /I "%restore%" == "2" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai.exe --web %*
|
||||
) ELSE IF /I "%restore%" == "3" (
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
echo Starting textual inversion training..
|
||||
python .venv\Scripts\invokeai-ti.exe --gui
|
||||
) ELSE IF /I "%restore%" == "4" (
|
||||
) ELSE IF /I "%choice%" == "4" (
|
||||
echo Starting model merging script..
|
||||
python .venv\Scripts\invokeai-merge.exe --gui
|
||||
) ELSE IF /I "%restore%" == "5" (
|
||||
) ELSE IF /I "%choice%" == "5" (
|
||||
echo Running invokeai-model-install...
|
||||
python .venv\Scripts\invokeai-model-install.exe
|
||||
) ELSE IF /I "%restore%" == "6" (
|
||||
) ELSE IF /I "%choice%" == "6" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
||||
) ELSE IF /I "%restore%" == "7" (
|
||||
) ELSE IF /I "%choice%" == "7" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --yes --default_only
|
||||
) ELSE IF /I "%restore%" == "8" (
|
||||
) ELSE IF /I "%choice%" == "8" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
@@ -54,15 +54,15 @@ IF /I "%restore%" == "1" (
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%restore%" == "9" (
|
||||
) ELSE IF /I "%choice%" == "9" (
|
||||
echo Running invokeai-update...
|
||||
python .venv\Scripts\invokeai-update.exe %*
|
||||
) ELSE IF /I "%restore%" == "10" (
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%restore%" == "q" (
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
echo Goodbye!
|
||||
goto ending
|
||||
) ELSE (
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MIT License
|
||||
|
||||
# Coauthored by Lincoln Stein, Eugene Brodsky and Joshua Kimsey
|
||||
# Copyright 2023, The InvokeAI Development Team
|
||||
|
||||
####
|
||||
# This launch script assumes that:
|
||||
# 1. it is located in the runtime directory,
|
||||
@@ -11,85 +16,168 @@
|
||||
|
||||
set -eu
|
||||
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
export INVOKEAI_ROOT="$scriptdir"
|
||||
PARAMS=$@
|
||||
|
||||
# set required env var for torch on mac MPS
|
||||
# Check to see if dialog is installed (it seems to be fairly standard, but good to check regardless) and if the user has passed the --no-tui argument to disable the dialog TUI
|
||||
tui=true
|
||||
if command -v dialog &>/dev/null; then
|
||||
# This must use $@ to properly loop through the arguments passed by the user
|
||||
for arg in "$@"; do
|
||||
if [ "$arg" == "--no-tui" ]; then
|
||||
tui=false
|
||||
# Remove the --no-tui argument to avoid errors later on when passing arguments to InvokeAI
|
||||
PARAMS=$(echo "$PARAMS" | sed 's/--no-tui//')
|
||||
break
|
||||
fi
|
||||
done
|
||||
else
|
||||
tui=false
|
||||
fi
|
||||
|
||||
# Set required env var for torch on mac MPS
|
||||
if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true
|
||||
do
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line interface"
|
||||
echo "2. browser-based UI"
|
||||
echo "3. run textual inversion training"
|
||||
echo "4. merge models (diffusers type only)"
|
||||
echo "5. download and install models"
|
||||
echo "6. change InvokeAI startup options"
|
||||
echo "7. re-run the configure script to fix a broken install"
|
||||
echo "8. open the developer console"
|
||||
echo "9. update InvokeAI"
|
||||
echo "10. command-line help"
|
||||
echo "Q - Quit"
|
||||
echo ""
|
||||
read -p "Please enter 1-10, Q: [2] " yn
|
||||
choice=${yn:='2'}
|
||||
case $choice in
|
||||
1)
|
||||
echo "Starting the InvokeAI command-line..."
|
||||
invokeai $@
|
||||
;;
|
||||
2)
|
||||
echo "Starting the InvokeAI browser-based UI..."
|
||||
invokeai --web $@
|
||||
;;
|
||||
3)
|
||||
echo "Starting Textual Inversion:"
|
||||
invokeai-ti --gui $@
|
||||
;;
|
||||
4)
|
||||
echo "Merging Models:"
|
||||
invokeai-merge --gui $@
|
||||
;;
|
||||
5)
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
7)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
echo "Developer Console:"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
echo "Update:"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
invokeai --help
|
||||
;;
|
||||
[qQ])
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection"
|
||||
exit;;
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
1)
|
||||
clear
|
||||
printf "Generate images with a browser-based interface\n"
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Explore InvokeAI nodes using a command-line interface\n"
|
||||
invokeai $PARAMS
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Textual inversion training\n"
|
||||
invokeai-ti --gui $PARAMS
|
||||
;;
|
||||
4)
|
||||
clear
|
||||
printf "Merge models (diffusers type only)\n"
|
||||
invokeai-merge --gui $PARAMS
|
||||
;;
|
||||
5)
|
||||
clear
|
||||
printf "Download and install models\n"
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
clear
|
||||
printf "Change InvokeAI startup options\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
7)
|
||||
clear
|
||||
printf "Re-run the configure script to fix a broken install\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
clear
|
||||
printf "Update InvokeAI\n"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
;;
|
||||
"HELP 1")
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
printf "Exiting...\n"
|
||||
exit
|
||||
;;
|
||||
esac
|
||||
done
|
||||
clear
|
||||
}
|
||||
|
||||
# Dialog-based TUI for launcing Invoke functions
|
||||
do_dialog() {
|
||||
options=(
|
||||
1 "Generate images with a browser-based interface"
|
||||
2 "Generate images using a command-line interface"
|
||||
3 "Textual inversion training"
|
||||
4 "Merge models (diffusers type only)"
|
||||
5 "Download and install models"
|
||||
6 "Change InvokeAI startup options"
|
||||
7 "Re-run the configure script to fix a broken install"
|
||||
8 "Open the developer console"
|
||||
9 "Update InvokeAI")
|
||||
|
||||
choice=$(dialog --clear \
|
||||
--backtitle "\Zb\Zu\Z3InvokeAI" \
|
||||
--colors \
|
||||
--title "What would you like to do?" \
|
||||
--ok-label "Run" \
|
||||
--cancel-label "Exit" \
|
||||
--help-button \
|
||||
--help-label "CLI Help" \
|
||||
--menu "Select an option:" \
|
||||
0 0 0 \
|
||||
"${options[@]}" \
|
||||
2>&1 >/dev/tty) || clear
|
||||
do_choice "$choice"
|
||||
clear
|
||||
}
|
||||
|
||||
# Command-line interface for launching Invoke functions
|
||||
do_line_input() {
|
||||
clear
|
||||
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Explore InvokeAI nodes using the command-line interface\n"
|
||||
printf "3: Run textual inversion training\n"
|
||||
printf "4: Merge models (diffusers type only)\n"
|
||||
printf "5: Download and install models\n"
|
||||
printf "6: Change InvokeAI startup options\n"
|
||||
printf "7: Re-run the configure script to fix a broken install\n"
|
||||
printf "8: Open the developer console\n"
|
||||
printf "9: Update InvokeAI\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
}
|
||||
|
||||
# Main IF statement for launching Invoke with either the TUI or CLI, and for checking if the user is in the developer console
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true; do
|
||||
if $tui; then
|
||||
# .dialogrc must be located in the same directory as the invoke.sh script
|
||||
export DIALOGRC="./.dialogrc"
|
||||
do_dialog
|
||||
else
|
||||
do_line_input
|
||||
fi
|
||||
done
|
||||
else # in developer console
|
||||
python --version
|
||||
echo "Press ^D to exit"
|
||||
printf "Press ^D to exit\n"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
@@ -38,52 +39,63 @@ def check_internet() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||
logger.info(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||
)
|
||||
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config),
|
||||
model_manager=get_model_manager(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config, logger),
|
||||
configuration=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
"""An image's metadata. Used only in HTTP responses."""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
@@ -1,128 +1,250 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from typing import Optional
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a result"""
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a thumbnail"""
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name, is_thumbnail=True
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/uploads/",
|
||||
"/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The image was uploaded successfully",
|
||||
"model": ImageResponse,
|
||||
},
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
415: {"description": "Image upload failed"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def upload_image(
|
||||
file: UploadFile, request: Request, response: Response
|
||||
) -> ImageResponse:
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(contents))
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
except:
|
||||
# Error opening the image
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.EXTERNAL,
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
(image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
|
||||
ImageType.UPLOAD, filename, img
|
||||
)
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_dto.image_url
|
||||
|
||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||
return image_dto
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_name=filename,
|
||||
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
metadata=ImageResponseMetadata(
|
||||
created=ctime,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||
)
|
||||
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image"""
|
||||
|
||||
return res
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
"/{image_origin}/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(
|
||||
image_origin, image_name, image_changes
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}",
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the full-resolution image",
|
||||
"content": {"image/png": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_origin: ResourceOrigin = Path(
|
||||
description="The type of full-resolution image file to get"
|
||||
),
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a full-resolution image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
||||
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=image_name,
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the image thumbnail",
|
||||
"content": {"image/webp": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_image_thumbnail(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_origin, image_name, thumbnail=True
|
||||
)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path, media_type="image/webp", content_disposition_type="inline"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_origin, image_name
|
||||
)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_origin, image_name, thumbnail=True
|
||||
)
|
||||
return ImageUrlsDTO(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
operation_id="list_images_with_metadata",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(
|
||||
default=ImageType.RESULT, description="The type of images to get"
|
||||
async def list_images_with_metadata(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include"
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
||||
return result
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
@@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@@ -112,19 +108,20 @@ async def update_model(
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
print(f">> Checking for model {model_name}...")
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
print(f">> Model not found")
|
||||
logger.error(f"Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
@@ -248,4 +245,4 @@ async def delete_model(model_name: str) -> None:
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
# except Exception as e:
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.responses import Response
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
@@ -76,7 +75,7 @@ async def get_session(
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
return session
|
||||
|
||||
@@ -99,7 +98,7 @@ async def add_node(
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
@@ -108,9 +107,9 @@ async def add_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
@@ -132,7 +131,7 @@ async def update_node(
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
@@ -141,9 +140,9 @@ async def update_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
@@ -162,7 +161,7 @@ async def delete_node(
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
@@ -171,9 +170,9 @@ async def delete_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.post(
|
||||
@@ -192,7 +191,7 @@ async def add_edge(
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
@@ -201,9 +200,9 @@ async def add_edge(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@@ -226,7 +225,7 @@ async def delete_edge(
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
edge = Edge(
|
||||
@@ -239,9 +238,9 @@ async def delete_edge(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
@@ -259,14 +258,14 @@ async def invoke_session(
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
) -> None:
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
if session.is_complete():
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
@@ -281,7 +280,7 @@ async def invoke_session(
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
) -> None:
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
ApiDependencies.invoker.cancel(session_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
@@ -10,13 +11,21 @@ from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
#This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.routers import sessions, models, images
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
# Create the app
|
||||
@@ -33,30 +42,21 @@ app.add_middleware(
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=config, event_handler_id=event_handler_id
|
||||
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
|
||||
|
||||
@@ -74,10 +74,9 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@@ -124,8 +123,7 @@ def custom_openapi():
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
@@ -145,16 +143,20 @@ def overridden_redoc():
|
||||
)
|
||||
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/",
|
||||
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||
html=True
|
||||
), name="ui"
|
||||
)
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
|
||||
config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
@@ -229,7 +230,7 @@ class HistoryCommand(BaseCommand):
|
||||
for i in range(min(self.count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = context.get_session().graph.get_node(entry_id)
|
||||
print(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
@@ -284,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand):
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
get_subactions = action._get_subactions
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._indent()
|
||||
if isinstance(action, argparse._SubParsersAction):
|
||||
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
|
||||
yield subaction
|
||||
else:
|
||||
for subaction in get_subactions():
|
||||
yield subaction
|
||||
self._dedent()
|
||||
|
||||
@@ -10,9 +10,11 @@ import shlex
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
from ...backend import ModelManager, Globals
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
@@ -130,13 +132,13 @@ class Completer(object):
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(model_manager)
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
@@ -152,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
histfile = Path(Globals.root, ".invoke_history")
|
||||
histfile = Path(services.configuration.root_dir / ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
@@ -160,8 +162,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
print(
|
||||
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
@@ -4,32 +4,41 @@ import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
from typing import (
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
@@ -44,7 +53,6 @@ class CliCommand(BaseModel):
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
@@ -65,7 +73,7 @@ def add_invocation_args(command_parser):
|
||||
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
@@ -182,50 +190,72 @@ def invoke_all(context: CliContext):
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
print(
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
raise SessionError()
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
completer = set_autocompleter(model_manager)
|
||||
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
metadata = CoreMetadataService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
metadata=metadata,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@@ -241,10 +271,18 @@ def invoke_cli():
|
||||
# print(services.session_manager.list())
|
||||
|
||||
context = CliContext(invoker, session, parser)
|
||||
set_autocompleter(services)
|
||||
|
||||
while True:
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
try:
|
||||
cmd_input = input("invoke> ")
|
||||
if command_line_args_exist:
|
||||
cmd_input = invocation_commands.pop(0)
|
||||
done = len(invocation_commands) == 0
|
||||
else:
|
||||
cmd_input = input("invoke> ")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl-c exits
|
||||
break
|
||||
@@ -365,12 +403,15 @@ def invoke_cli():
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except ValidationError:
|
||||
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.invocation_services import InvocationServices
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
@@ -75,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
#fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||
#fmt: on
|
||||
|
||||
|
||||
@@ -92,6 +96,7 @@ class UIConfig(TypedDict, total=False):
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
"control",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
from pydantic import Field
|
||||
from pydantic import Field, validator
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
)
|
||||
@@ -22,9 +22,17 @@ class IntCollectionOutput(BaseInvocationOutput):
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of floats"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = Field(default=[], description="The float collection")
|
||||
|
||||
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
type: Literal["range"] = "range"
|
||||
|
||||
@@ -33,12 +41,34 @@ class RangeInvocation(BaseInvocation):
|
||||
stop: int = Field(default=10, description="The stop of the range")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
@validator("stop")
|
||||
def stop_gt_start(cls, v, values):
|
||||
if "start" in values and v <= values["start"]:
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
|
||||
type: Literal["range_of_size"] = "range_of_size"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
size: int = Field(default=1, description="The number of values")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.start + self.size, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
@@ -50,11 +80,11 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: Optional[int] = Field(
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=np.iinfo(np.int32).max,
|
||||
description="The seed for the RNG",
|
||||
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
|
||||
266
invokeai/app/invocations/compel.py
Normal file
266
invokeai/app/invocations/compel.py
Normal file
@@ -0,0 +1,266 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from ...backend.prompting.conditioning import try_parse_legacy_blend
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment, Conjunction,
|
||||
)
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||
#fmt: on
|
||||
|
||||
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
model: str = Field(default="", description="Model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
# TODO: load without model
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
pipeline = model["model"]
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
)
|
||||
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
|
||||
if legacy_blend is not None:
|
||||
conjunction = legacy_blend
|
||||
else:
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in blend.prompts
|
||||
]
|
||||
)
|
||||
elif type(prompt) is Conjunction:
|
||||
conjunction: Conjunction = prompt
|
||||
return sum(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in conjunction.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
||||
)
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> [str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts)>1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(
|
||||
p,
|
||||
tokenizer,
|
||||
display_label_prefix=this_display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
for i, c in enumerate(blend.prompts):
|
||||
log_tokenization_for_prompt_object(
|
||||
c,
|
||||
tokenizer,
|
||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
||||
)
|
||||
elif type(p) is FlattenedPrompt:
|
||||
flattened_prompt: FlattenedPrompt = p
|
||||
if flattened_prompt.wants_cross_attention_control:
|
||||
original_fragments = []
|
||||
edited_fragments = []
|
||||
for f in flattened_prompt.children:
|
||||
if type(f) is CrossAttentionControlSubstitute:
|
||||
original_fragments += f.original
|
||||
edited_fragments += f.edited
|
||||
else:
|
||||
original_fragments.append(f)
|
||||
edited_fragments.append(f)
|
||||
|
||||
original_text = " ".join([x.text for x in original_fragments])
|
||||
log_tokenization_for_text(
|
||||
original_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap originals)",
|
||||
)
|
||||
edited_text = " ".join([x.text for x in edited_fragments])
|
||||
log_tokenization_for_text(
|
||||
edited_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace("</w>", " ")
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
else:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
print(f"{discarded}\x1b[0m")
|
||||
459
invokeai/app/invocations/controlnet_image_processors.py
Normal file
459
invokeai/app/invocations/controlnet_image_processors.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# InvokeAI nodes for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import float
|
||||
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
HEDdetector,
|
||||
LineartDetector,
|
||||
LineartAnimeDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
OpenposeDetector,
|
||||
PidiNetDetector,
|
||||
ContentShuffleDetector,
|
||||
ZoeDetector,
|
||||
MediapipeFaceDetector,
|
||||
)
|
||||
|
||||
from .image import ImageOutput, PILInvocationConfig
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
||||
##############################################
|
||||
"lllyasviel/sd-controlnet-canny",
|
||||
"lllyasviel/sd-controlnet-depth",
|
||||
"lllyasviel/sd-controlnet-hed",
|
||||
"lllyasviel/sd-controlnet-seg",
|
||||
"lllyasviel/sd-controlnet-openpose",
|
||||
"lllyasviel/sd-controlnet-scribble",
|
||||
"lllyasviel/sd-controlnet-normal",
|
||||
"lllyasviel/sd-controlnet-mlsd",
|
||||
|
||||
#############################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||
#############################################
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
"lllyasviel/control_v11p_sd15_openpose",
|
||||
"lllyasviel/control_v11p_sd15_seg",
|
||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
"lllyasviel/control_v11p_sd15_normalbae",
|
||||
"lllyasviel/control_v11p_sd15_scribble",
|
||||
"lllyasviel/control_v11p_sd15_mlsd",
|
||||
"lllyasviel/control_v11p_sd15_softedge",
|
||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
"lllyasviel/control_v11p_sd15_lineart",
|
||||
"lllyasviel/control_v11p_sd15_inpaint",
|
||||
# "lllyasviel/control_v11u_sd15_tile",
|
||||
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
||||
"lllyasviel/control_v11e_sd15_shuffle",
|
||||
"lllyasviel/control_v11e_sd15_ip2p",
|
||||
"lllyasviel/control_v11f1e_sd15_tile",
|
||||
|
||||
#################################################
|
||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||
##################################################
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
||||
"thibaud/controlnet-sd21-hed-diffusers",
|
||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
||||
"thibaud/controlnet-sd21-color-diffusers",
|
||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||
|
||||
##############################################
|
||||
# ControlNetMediaPipeface, ControlNet v1.1
|
||||
##############################################
|
||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
||||
]
|
||||
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
# fmt: off
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
),
|
||||
)
|
||||
|
||||
# TODO: move image processors to separate file (image_analysis.py
|
||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_processor"] = "image_processor"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to process")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def run_processor(self, image):
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# FIXME: what happened to image metadata?
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
# )
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.services.images.create(
|
||||
image=processed_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
width = image_dto.width,
|
||||
# height=processed_image.height,
|
||||
height = image_dto.height,
|
||||
# mode=processed_image.mode,
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies HED edge detection to image"""
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = hed_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = lineart_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
coarse=self.coarse)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art anime processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Openpose processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
# Inputs
|
||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Midas depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
# Inputs
|
||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies NormalBae processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies MLSD processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d)
|
||||
return processed_image
|
||||
|
||||
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies PIDI processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble)
|
||||
return processed_image
|
||||
|
||||
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies content shuffle processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
# Inputs
|
||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
@@ -7,9 +7,9 @@ import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class CvInvocationConfig(BaseModel):
|
||||
@@ -26,24 +26,27 @@ class CvInvocationConfig(BaseModel):
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
"""Simple inpaint using opencv."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||
cv_mask = numpy.array(ImageOps.invert(mask))
|
||||
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
|
||||
|
||||
# Inpaint
|
||||
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
|
||||
@@ -52,18 +55,20 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
# TODO: consider making a utility function
|
||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=image_inpainted,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_inpainted,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
from diffusers import ControlNetModel
|
||||
@@ -10,15 +10,22 @@ import torch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
@@ -46,18 +53,16 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
||||
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
||||
# control_strength: Optional[float] = Field(default=1.0, ge=0, le=1, description="The strength of the controlnet")
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
@@ -75,13 +80,14 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get(
|
||||
self.control_image.image_type, self.control_image.image_name
|
||||
else context.services.images.get_pil_image(
|
||||
self.control_image.image_origin, self.control_image.image_name
|
||||
)
|
||||
)
|
||||
# loading controlnet model
|
||||
@@ -89,6 +95,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
control_model = None
|
||||
else:
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(self.control_model,
|
||||
torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
@@ -111,25 +118,22 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(
|
||||
image_type, image_name, generate_output.image, metadata
|
||||
)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image_dto = context.services.images.create(
|
||||
image=generate_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -165,11 +169,13 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = None
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
@@ -183,7 +189,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
@@ -194,25 +199,22 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -223,6 +225,39 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(
|
||||
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||
)
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(
|
||||
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||
)
|
||||
tile_size: int = Field(
|
||||
default=32, ge=1, description="The tile infill method size (px)"
|
||||
)
|
||||
infill_method: INFILL_METHODS = Field(
|
||||
default=DEFAULT_INFILL_METHOD,
|
||||
description="The method used to infill empty regions (px)",
|
||||
)
|
||||
inpaint_width: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the inpaint region (px)",
|
||||
)
|
||||
inpaint_height: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the inpaint region (px)",
|
||||
)
|
||||
inpaint_fill: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The solid infill method color",
|
||||
)
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
@@ -247,14 +282,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
@@ -268,8 +303,8 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
@@ -280,23 +315,20 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional
|
||||
import io
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageField, ImageType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -30,32 +31,14 @@ class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
||||
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["type", "image", "width", "height", "mode"]
|
||||
}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
image_type: ImageType, image_name: str, image: Image.Image
|
||||
) -> ImageOutput:
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
image_field = ImageField(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=image_field,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
mode=image.mode,
|
||||
)
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
@@ -64,6 +47,8 @@ class MaskOutput(BaseInvocationOutput):
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
@@ -82,16 +67,20 @@ class LoadImageInvocation(BaseInvocation):
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to load"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image_type, self.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image_type,
|
||||
image_name=self.image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -101,32 +90,37 @@ class ShowImageInvocation(BaseInvocation):
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to show")
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
if image:
|
||||
image.show()
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image.image_type,
|
||||
image_name=self.image.image_name,
|
||||
image=image,
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
type: Literal["img_crop"] = "img_crop"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to crop")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
@@ -134,8 +128,8 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_crop = Image.new(
|
||||
@@ -143,49 +137,53 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image_dto = context.services.images.create(
|
||||
image=image_crop,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
type: Literal["img_paste"] = "img_paste"
|
||||
|
||||
# Inputs
|
||||
base_image: ImageField = Field(default=None, description="The base image")
|
||||
image: ImageField = Field(default=None, description="The image to paste")
|
||||
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get(
|
||||
self.base_image.image_type, self.base_image.image_name
|
||||
base_image = context.services.images.get_pil_image(
|
||||
self.base_image.image_origin, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@@ -201,20 +199,22 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=new_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=new_image,
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@@ -225,47 +225,169 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=image_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return MaskOutput(
|
||||
mask=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_mask, metadata)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
|
||||
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(
|
||||
self.image1.image_origin, self.image1.image_name
|
||||
)
|
||||
image2 = context.services.images.get_pil_image(
|
||||
self.image2.image_origin, self.image2.image_name
|
||||
)
|
||||
|
||||
multiply_image = ImageChops.multiply(image1, image2)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=multiply_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=channel_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=converted_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["blur"] = "blur"
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to blur")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
blur = (
|
||||
@@ -275,36 +397,149 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=blur_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, blur_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=blur_image
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
PIL_RESAMPLING_MODES = Literal[
|
||||
"nearest",
|
||||
"box",
|
||||
"bilinear",
|
||||
"hamming",
|
||||
"bicubic",
|
||||
"lanczos",
|
||||
]
|
||||
|
||||
|
||||
PIL_RESAMPLING_MAP = {
|
||||
"nearest": Image.Resampling.NEAREST,
|
||||
"box": Image.Resampling.BOX,
|
||||
"bilinear": Image.Resampling.BILINEAR,
|
||||
"hamming": Image.Resampling.HAMMING,
|
||||
"bicubic": Image.Resampling.BICUBIC,
|
||||
"lanczos": Image.Resampling.LANCZOS,
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
resize_image = image.resize(
|
||||
(self.width, self.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
width = int(image.width * self.scale_factor)
|
||||
height = int(image.height * self.scale_factor)
|
||||
|
||||
resize_image = image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lerp"] = "lerp"
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
@@ -312,36 +547,40 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=lerp_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=lerp_image
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["ilerp"] = "ilerp"
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to lerp")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
@@ -354,16 +593,20 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
image_type = ImageType.INTERMEDIATE
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=ilerp_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
245
invokeai/app/invocations/infill.py
Normal file
245
invokeai/app/invocations/infill.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
)
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
|
||||
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
color: ColorField = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Union[ImageField, None] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -1,29 +1,42 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import einops
|
||||
from typing import Literal, Optional, Union, List
|
||||
|
||||
from compel import Compel
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.models.image import ImageCategory
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from .controlnet_image_processors import ControlField
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
from ..services.image_file_storage import ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput, build_image_output
|
||||
from .image import ImageField, ImageOutput
|
||||
from .compel import ConditioningField
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import DiffusionPipeline, ControlNetModel
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
@@ -37,41 +50,55 @@ class LatentsField(BaseModel):
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latent_output"] = "latent_output"
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -102,19 +129,15 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
|
||||
# Schema customisation
|
||||
@@ -125,47 +148,62 @@ class NoiseInvocation(BaseInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
@validator("seed", pre=True)
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
|
||||
return v % SEED_MAX
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(choose_torch_device())
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=name)
|
||||
)
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from a prompt."""
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l"] = "t2l"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
|
||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -191,37 +229,123 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
# if isinstance(model, DiffusionPipeline):
|
||||
# for component in [model.unet, model.vae]:
|
||||
# configure_model_padding(component,
|
||||
# self.seamless,
|
||||
# self.seamless_axes
|
||||
# )
|
||||
# else:
|
||||
# configure_model_padding(model,
|
||||
# self.seamless,
|
||||
# self.seamless_axes
|
||||
# )
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
self.cfg_scale,
|
||||
extra_conditioning_info,
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
def prep_control_data(self,
|
||||
context: InvocationContext,
|
||||
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
if control_input is None:
|
||||
# print("control input is None")
|
||||
control_list = None
|
||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||
# print("control input is empty list")
|
||||
control_list = None
|
||||
elif isinstance(control_input, ControlField):
|
||||
# print("control input is ControlField")
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||
# print("control input is list[ControlField]")
|
||||
control_list = control_input
|
||||
else:
|
||||
# print("input control is unrecognized:", type(self.control))
|
||||
control_list = None
|
||||
if (control_list is None):
|
||||
control_data = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
# handle control models
|
||||
if ("," in control_info.control_model):
|
||||
control_model_split = control_info.control_model.split(",")
|
||||
control_name = control_model_split[0]
|
||||
control_subfolder = control_model_split[1]
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(control_name,
|
||||
subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
|
||||
control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = model.prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
)
|
||||
control_item = ControlNetData(model=control_model,
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
@@ -234,26 +358,29 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
@@ -261,21 +388,23 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
@@ -288,7 +417,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
@@ -296,11 +431,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
@@ -308,6 +439,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
@@ -315,10 +447,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@@ -354,18 +484,143 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
# what happened to metadata?
|
||||
# metadata = context.services.metadata.build_metadata(
|
||||
# session_id=context.graph_execution_state_id, node=self
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# new (post Image service refactor) way of using services to save image
|
||||
# and gnenerate unique image_name
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image
|
||||
)
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
]
|
||||
|
||||
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
scale_factor=self.scale_factor,
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
# image = context.services.images.get(
|
||||
# self.image.image_type, self.image.image_name
|
||||
# )
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, latents)
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
@@ -3,8 +3,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
@@ -21,19 +27,30 @@ class MathInvocationConfig(BaseModel):
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""A float output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["float_output"] = "float_output"
|
||||
param: float = Field(default=None, description="The output float")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
@@ -41,11 +58,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
@@ -53,11 +71,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
@@ -65,11 +84,26 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
|
||||
237
invokeai/app/invocations/param_easing.py
Normal file
237
invokeai/app/invocations/param_easing.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
|
||||
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
logger = InvokeAILogger.getLogger(name="StepParamEasing")
|
||||
logger.debug("start_step: " + str(start_step))
|
||||
logger.debug("end_step: " + str(end_step))
|
||||
logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
logger.debug("num_presteps: " + str(num_presteps))
|
||||
logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
logger.debug("prelist: " + str(prelist))
|
||||
logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = list()
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
from .math import IntOutput, FloatOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
@@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
|
||||
@@ -2,21 +2,23 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["restore_face"] = "restore_face"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
@@ -26,8 +28,8 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
@@ -39,18 +41,20 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
||||
@@ -4,22 +4,22 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["upscale"] = "upscale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
#fmt: on
|
||||
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
@@ -30,8 +30,8 @@ class UpscaleInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
@@ -43,18 +43,20 @@ class UpscaleInvocation(BaseInvocation):
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=results[0][0]
|
||||
)
|
||||
@@ -3,12 +3,12 @@ from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
logger = model_manager.logger
|
||||
if model_name and not model_manager.valid_model(model_name):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
print(
|
||||
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
|
||||
)
|
||||
else:
|
||||
model = model_manager.get_model(model_name)
|
||||
|
||||
return model
|
||||
|
||||
@@ -1,29 +1,93 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
|
||||
def is_image_type(obj):
|
||||
try:
|
||||
ImageType(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
|
||||
- INTERNAL: The resource was created by the application.
|
||||
- EXTERNAL: The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
INTERNAL = "internal"
|
||||
"""The resource was created by the application."""
|
||||
EXTERNAL = "external"
|
||||
"""The resource was not created by the application.
|
||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOriginException(ValueError):
|
||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid resource origin."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||
"""The category of an image.
|
||||
|
||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||
- MASK: The image is a mask image.
|
||||
- CONTROL: The image is a ControlNet control image.
|
||||
- USER: The image is a user-provide image.
|
||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||
"""
|
||||
|
||||
GENERAL = "general"
|
||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||
MASK = "mask"
|
||||
"""MASK: The image is a mask image."""
|
||||
CONTROL = "control"
|
||||
"""CONTROL: The image is a ControlNet control image."""
|
||||
USER = "user"
|
||||
"""USER: The image is a user-provide image."""
|
||||
OTHER = "other"
|
||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||
|
||||
|
||||
class InvalidImageCategoryException(ValueError):
|
||||
"""Raised when a provided value is not a valid ImageCategory.
|
||||
|
||||
Subclasses `ValueError`.
|
||||
"""
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: ImageType = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
image_origin: ResourceOrigin = Field(
|
||||
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
||||
schema_extra = {"required": ["image_origin", "image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
93
invokeai/app/models/metadata.py
Normal file
93
invokeai/app/models/metadata.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Optional, Union, List
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""
|
||||
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||
|
||||
Also includes any metadata from the image's PNG tEXt chunks.
|
||||
|
||||
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||
of a given node.
|
||||
|
||||
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
"""
|
||||
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
||||
won't add any fields that are not already defined, but other a different metadata service
|
||||
implementation might.
|
||||
"""
|
||||
|
||||
type: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The type of the ancestor node of the image output node.",
|
||||
)
|
||||
"""The type of the ancestor node of the image output node."""
|
||||
positive_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The positive conditioning."
|
||||
)
|
||||
"""The positive conditioning"""
|
||||
negative_conditioning: Optional[StrictStr] = Field(
|
||||
default=None, description="The negative conditioning."
|
||||
)
|
||||
"""The negative conditioning"""
|
||||
width: Optional[StrictInt] = Field(
|
||||
default=None, description="Width of the image/latents in pixels."
|
||||
)
|
||||
"""Width of the image/latents in pixels"""
|
||||
height: Optional[StrictInt] = Field(
|
||||
default=None, description="Height of the image/latents in pixels."
|
||||
)
|
||||
"""Height of the image/latents in pixels"""
|
||||
seed: Optional[StrictInt] = Field(
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
# cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Union[float, list[float]] = Field(
|
||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
steps: Optional[StrictInt] = Field(
|
||||
default=None, description="The number of steps used for inference."
|
||||
)
|
||||
"""The number of steps used for inference"""
|
||||
scheduler: Optional[StrictStr] = Field(
|
||||
default=None, description="The scheduler used for inference."
|
||||
)
|
||||
"""The scheduler used for inference"""
|
||||
model: Optional[StrictStr] = Field(
|
||||
default=None, description="The model used for inference."
|
||||
)
|
||||
"""The model used for inference"""
|
||||
strength: Optional[StrictFloat] = Field(
|
||||
default=None,
|
||||
description="The strength used for image-to-image/latents-to-latents.",
|
||||
)
|
||||
"""The strength used for image-to-image/latents-to-latents."""
|
||||
latents: Optional[StrictStr] = Field(
|
||||
default=None, description="The ID of the initial latents."
|
||||
)
|
||||
"""The ID of the initial latents"""
|
||||
vae: Optional[StrictStr] = Field(
|
||||
default=None, description="The VAE used for decoding."
|
||||
)
|
||||
"""The VAE used for decoding"""
|
||||
unet: Optional[StrictStr] = Field(
|
||||
default=None, description="The UNet used dor inference."
|
||||
)
|
||||
"""The UNet used dor inference"""
|
||||
clip: Optional[StrictStr] = Field(
|
||||
default=None, description="The CLIP Encoder used for conditioning."
|
||||
)
|
||||
"""The CLIP Encoder used for conditioning"""
|
||||
extra: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||
)
|
||||
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
||||
581
invokeai/app/services/config.py
Normal file
581
invokeai/app/services/config.py
Normal file
@@ -0,0 +1,581 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
'''Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
has a top-level key of "InvokeAI" and subheadings for each of the
|
||||
categories returned by `invokeai --help`. The file looks like this:
|
||||
|
||||
[file: invokeai.yaml]
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
embedding_dir: embeddings
|
||||
lora_dir: loras
|
||||
autoconvert_dir: null
|
||||
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
|
||||
Models:
|
||||
model: stable-diffusion-1.5
|
||||
embeddings: true
|
||||
Memory/Performance:
|
||||
xformers_enabled: false
|
||||
sequential_guidance: false
|
||||
precision: float16
|
||||
max_loaded_models: 4
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
nsfw_checker: true
|
||||
restore: true
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
Web Server:
|
||||
host: 127.0.0.1
|
||||
port: 8081
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||
OmegaConf dictionary object initialization time:
|
||||
|
||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||
conf = InvokeAIAppConfig()
|
||||
conf.parse_args(conf=omegaconf)
|
||||
|
||||
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||
at initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf.parse_args(argv=['--xformers_enabled'])
|
||||
|
||||
It is also possible to set a value at initialization time. However, if
|
||||
you call parse_args() it may be overwritten.
|
||||
|
||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
# False
|
||||
|
||||
|
||||
To avoid this, use `get_config()` to retrieve the application-wide
|
||||
configuration object. This will retain any properties set at object
|
||||
creation time:
|
||||
|
||||
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
||||
conf.parse_args(argv=['--no-xformers'])
|
||||
conf.xformers_enabled
|
||||
# True
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
form: "INVOKEAI_<setting>", as in:
|
||||
|
||||
export INVOKEAI_port=8080
|
||||
|
||||
Order of precedence (from highest):
|
||||
1) initialization options
|
||||
2) command line options
|
||||
3) environment variable options
|
||||
4) config file options
|
||||
5) pydantic defaults
|
||||
|
||||
Typical usage at the top level file:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
conf.parse_args()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
Typical usage in a backend module:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
|
||||
Computed properties:
|
||||
|
||||
The InvokeAIAppConfig object has a series of properties that
|
||||
resolve paths relative to the runtime root directory. They each return
|
||||
a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
|
||||
In most cases, you will want to create a single InvokeAIAppConfig
|
||||
object for the entire application. The InvokeAIAppConfig.get_config() function
|
||||
does this:
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args() # read values from the command line/config file
|
||||
print(config.root)
|
||||
|
||||
# Subclassing
|
||||
|
||||
If you wish to create a similar class, please subclass the
|
||||
`InvokeAISettings` class and define a Literal field named "type",
|
||||
which is set to the desired top-level name. For example, to create a
|
||||
"InvokeBatch" configuration, define like this:
|
||||
|
||||
class InvokeBatch(InvokeAISettings):
|
||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
||||
|
||||
This will now read and write from the "InvokeBatch" section of the
|
||||
config file, look for environment variables named INVOKEBATCH_*, and
|
||||
accept the command-line arguments `--node_count` and `--cpu_count`. The
|
||||
two configs are kept in separate sections of the config file:
|
||||
|
||||
# invokeai.yaml
|
||||
|
||||
InvokeBatch:
|
||||
Resources:
|
||||
node_count: 1
|
||||
cpu_count: 8
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
DB_FILE = Path('invokeai.db')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
|
||||
def to_yaml(self)->str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)['type'])[0]
|
||||
field_dict = dict({type:dict()})
|
||||
for name,field in self.__fields__.items():
|
||||
if name in cls._excluded():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self,name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if 'type' in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||
|
||||
initconf = cls.initconf.get(settings_stanza) \
|
||||
if cls.initconf and settings_stanza in cls.initconf \
|
||||
else OmegaConf.create()
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key,value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category","Uncategorized")
|
||||
env_name = env_prefix + '_' + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'Uncategorized'
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
return ['type','initconf']
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs='*',
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
|
||||
or
|
||||
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
|
||||
)
|
||||
):
|
||||
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
|
||||
#fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
|
||||
controlnet_dir : Path = Field(default="controlnets", description='Path to directory of ControlNet models.', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
|
||||
#fmt: on
|
||||
|
||||
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||
'''
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
'''
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
super().parse_args(argv)
|
||||
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||
|
||||
@classmethod
|
||||
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
'''
|
||||
if cls.singleton_config is None \
|
||||
or type(cls.singleton_config)!=cls \
|
||||
or (kwargs and cls.singleton_init != kwargs):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
if self.root:
|
||||
return Path(self.root).expanduser()
|
||||
else:
|
||||
return self.find_root()
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
Alias for above.
|
||||
'''
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def init_file_path(self)->Path:
|
||||
'''
|
||||
Path to invokeai.yaml
|
||||
'''
|
||||
return self._resolve(INIT_FILE)
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self)->Path:
|
||||
'''
|
||||
Path to the invokeai.db file.
|
||||
'''
|
||||
return self._resolve(self.db_dir) / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
Path to models configuration file.
|
||||
'''
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self)->Path:
|
||||
'''
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
'''
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def cache_dir(self)->Path:
|
||||
'''
|
||||
Path to the global cache directory for HuggingFace hub-managed models
|
||||
'''
|
||||
return self.models_dir / "hub"
|
||||
|
||||
@property
|
||||
def models_dir(self)->Path:
|
||||
'''
|
||||
Path to the models directory
|
||||
'''
|
||||
return self._resolve("models")
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
'''
|
||||
Path to the textual inversion embeddings directory.
|
||||
'''
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def lora_path(self)->Path:
|
||||
'''
|
||||
Path to the LoRA models directory.
|
||||
'''
|
||||
return self._resolve(self.lora_dir) if self.lora_dir else None
|
||||
|
||||
@property
|
||||
def controlnet_path(self)->Path:
|
||||
'''
|
||||
Path to the controlnet models directory.
|
||||
'''
|
||||
return self._resolve(self.controlnet_dir) if self.controlnet_dir else None
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
'''
|
||||
Path to the GFPGAN model.
|
||||
'''
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision=='float32'
|
||||
|
||||
@property
|
||||
def disable_xformers(self)->bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self)->bool:
|
||||
"""Return true if patchmatch true"""
|
||||
return self.patchmatch
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
'''
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
'''
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
@@ -1,4 +1,5 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
@@ -16,38 +17,45 @@ def create_text_to_image() -> LibraryGraph:
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'seed': ParamIntInvocation(id='seed', a=-1),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': TextToLatentsInvocation(id='4'),
|
||||
'5': LatentsToImageInvocation(id='5')
|
||||
'4': CompelInvocation(id='4'),
|
||||
'5': CompelInvocation(id='5'),
|
||||
'6': TextToLatentsInvocation(id='6'),
|
||||
'7': LatentsToImageInvocation(id='7'),
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
|
||||
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
|
||||
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height'),
|
||||
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
||||
ExposedNodeOutput(node_path='7', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
|
||||
@@ -60,6 +60,33 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
from typing import Optional, Union, List, get_args
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
if not t1_args:
|
||||
# t1 is a single type
|
||||
return t1 in t2_args
|
||||
else:
|
||||
# t1 is a Union, check that all of its types are in t2_args
|
||||
return all(arg in t2_args for arg in t1_args)
|
||||
|
||||
def is_list_or_contains_list(t):
|
||||
t_args = get_args(t)
|
||||
|
||||
# If the type is a List
|
||||
if get_origin(t) is list:
|
||||
return True
|
||||
|
||||
# If the type is a Union
|
||||
elif t_args:
|
||||
# Check if any of the types in the Union is a List
|
||||
for arg in t_args:
|
||||
if get_origin(arg) is list:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if not from_type:
|
||||
@@ -85,7 +112,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
if not issubclass(from_type, to_type):
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
@@ -135,6 +163,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
@@ -162,6 +191,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(
|
||||
@@ -361,7 +391,7 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
@@ -372,41 +402,41 @@ class Graph(BaseModel):
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
@@ -692,7 +722,11 @@ class Graph(BaseModel):
|
||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all((get_origin(f) == list for f in output_fields)):
|
||||
# if not all((get_origin(f) == list for f in output_fields)):
|
||||
# return False
|
||||
|
||||
# Verify that all outputs are lists
|
||||
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
@@ -711,6 +745,13 @@ class Graph(BaseModel):
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||
g = nx.DiGraph()
|
||||
g.add_nodes_from([n for n in self.nodes.items()])
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||
) -> nx.DiGraph:
|
||||
@@ -816,11 +857,9 @@ class GraphExecutionState(BaseModel):
|
||||
if next_node is None:
|
||||
prepared_id = self._prepare()
|
||||
|
||||
# TODO: prepare multiple nodes at once?
|
||||
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
|
||||
# prepared_id = self._prepare()
|
||||
|
||||
if prepared_id is not None:
|
||||
# Prepare as many nodes as we can
|
||||
while prepared_id is not None:
|
||||
prepared_id = self._prepare()
|
||||
next_node = self._get_next_node()
|
||||
|
||||
# Get values from edges
|
||||
@@ -967,14 +1006,30 @@ class GraphExecutionState(BaseModel):
|
||||
# Get flattened source graph
|
||||
g = self.graph.nx_graph_flat()
|
||||
|
||||
# Find next unprepared node where all source nodes are executed
|
||||
# Find next node that:
|
||||
# - was not already prepared
|
||||
# - is not an iterate node whose inputs have not been executed
|
||||
# - does not have an unexecuted iterate ancestor
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node_id = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
# exclude nodes that have already been prepared
|
||||
if n not in self.source_prepared_mapping
|
||||
and all((e[0] in self.executed for e in g.in_edges(n)))
|
||||
# exclude iterate nodes whose inputs have not been executed
|
||||
and not (
|
||||
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
|
||||
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
|
||||
)
|
||||
# exclude nodes who have unexecuted iterate ancestors
|
||||
and not any(
|
||||
(
|
||||
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
|
||||
and a not in self.executed # ...that is not executed
|
||||
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
|
||||
)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -1071,9 +1126,22 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
def _get_next_node(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the deepest node that is ready to be executed"""
|
||||
g = self.execution_graph.nx_graph()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if next_node is None:
|
||||
return None
|
||||
|
||||
|
||||
204
invokeai/app/services/image_file_storage.py
Normal file
204
invokeai/app/services/image_file_storage.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||
# 500 internal server error. I don't like having this method on the service.
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
|
||||
class DiskImageFileStorage(ImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, PILImageType]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_origin in ResourceOrigin:
|
||||
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
except FileNotFoundError as e:
|
||||
raise ImageFileNotFoundException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", metadata.json())
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
except Exception as e:
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_origin, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
except Exception as e:
|
||||
raise ImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(basename)
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
return abspath
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def __get_cache(self, image_name: str) -> PILImageType | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
419
invokeai/app/services/image_record_storage.py
Normal file
419
invokeai/app/services/image_record_storage.py
Normal file
@@ -0,0 +1,419 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
offset: int = Field(description="Offset from which to retrieve items")
|
||||
limit: int = Field(description="Limit of items to get")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ImageRecordNotFoundException(Exception):
|
||||
"""Raised when an image record is not found."""
|
||||
|
||||
def __init__(self, message="Image record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordSaveException(Exception):
|
||||
"""Raised when an image record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordDeleteException(Exception):
|
||||
"""Raised when an image record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the image record store."""
|
||||
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = current_timestamp
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def get(
|
||||
self, image_origin: ResourceOrigin, image_name: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
category_strings = list(
|
||||
map(lambda c: c.value, set(categories))
|
||||
)
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
images_params.append(limit)
|
||||
images_params.append(offset)
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
height: int,
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else metadata.json(exclude_none=True)
|
||||
)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
@@ -1,238 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import (
|
||||
InvokeAIMetadata,
|
||||
MetadataServiceBase,
|
||||
build_invokeai_metadata_pnginfo,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a paginated list of images."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the path to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> Tuple[str, str, int]:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image path, thumbnail path, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
"""Creates a unique contextual image filename."""
|
||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
__metadata_service: MetadataServiceBase
|
||||
|
||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
||||
self.__output_folder = output_folder
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
self.__metadata_service = metadata_service
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
|
||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
||||
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
||||
metadata=ImageResponseMetadata(
|
||||
created=int(os.path.getctime(path)),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = PILImage.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", basename
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
|
||||
return path
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> Tuple[str, str, int]:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
||||
# TODO: Reading the image and then saving it strips the metadata...
|
||||
if metadata:
|
||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path) # this saved image has an empty info
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
|
||||
return (image_path, thumbnail_path, int(os.path.getctime(image_path)))
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: Image):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
del self.__cache[cache_id]
|
||||
393
invokeai/app/services/images.py
Normal file
393
invokeai/app/services/images.py
Normal file
@@ -0,0 +1,393 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
InvalidImageCategoryException,
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import (
|
||||
ImageFileDeleteException,
|
||||
ImageFileNotFoundException,
|
||||
ImageFileSaveException,
|
||||
ImageFileStorageBase,
|
||||
)
|
||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.resource_name import NameServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.graph import GraphExecutionState
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
"""Updates an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image's path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
names: NameServiceBase
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
self.names = names
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
|
||||
image_name = self._services.names.create_image_name()
|
||||
|
||||
metadata = self._get_metadata(session_id, node_id)
|
||||
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
created_at = self._services.records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_origin, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
# Meta fields
|
||||
created_at=created_at,
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem saving image record and file")
|
||||
raise e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, image_origin, changes)
|
||||
return self.get_dto(image_origin, image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_origin, image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_origin, image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_origin, image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_origin, image_name),
|
||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
r.image_origin, r.image_name, True
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
)
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_origin, image_name)
|
||||
self._services.records.delete(image_origin, image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
except ImageFileDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem deleting image record and file")
|
||||
raise e
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Union[ImageMetadata, None]:
|
||||
"""Get the metadata for a node."""
|
||||
metadata = None
|
||||
|
||||
if node_id is not None and session_id is not None:
|
||||
session = self._services.graph_execution_manager.get(session_id)
|
||||
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
||||
|
||||
return metadata
|
||||
@@ -1,50 +1,60 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.backend import ModelManager
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.restoration_services import RestorationServices
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.config import InvokeAISettings
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .latent_storage import LatentsStorageBase
|
||||
from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
events: "EventServiceBase"
|
||||
latents: "LatentsStorageBase"
|
||||
queue: "InvocationQueueABC"
|
||||
model_manager: "ModelManager"
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageService"
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
self,
|
||||
model_manager: "ModelManager",
|
||||
events: "EventServiceBase",
|
||||
logger: "Logger",
|
||||
latents: "LatentsStorageBase",
|
||||
images: "ImageService",
|
||||
queue: "InvocationQueueABC",
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: "RestorationServices",
|
||||
configuration: "InvokeAISettings",
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
|
||||
@@ -22,7 +22,8 @@ class Invoker:
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> str | None:
|
||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
@@ -49,7 +50,7 @@ class Invoker:
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
@@ -71,18 +72,12 @@ class Invoker:
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the invoker. A new invoker will have to be created to execute further."""
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class LatentsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -47,8 +47,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.set(name, data)
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
@@ -80,7 +80,7 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
|
||||
@@ -1,96 +1,142 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, TypedDict
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Union
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.image import ImageType, is_image_type
|
||||
|
||||
|
||||
class MetadataImageField(TypedDict):
|
||||
"""Pydantic-less ImageField, used for metadata parsing."""
|
||||
|
||||
image_type: ImageType
|
||||
image_name: str
|
||||
|
||||
|
||||
class MetadataLatentsField(TypedDict):
|
||||
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||||
|
||||
latents_name: str
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||||
]
|
||||
|
||||
|
||||
class InvokeAIMetadata(TypedDict, total=False):
|
||||
"""InvokeAI-specific metadata format."""
|
||||
|
||||
session_id: Optional[str]
|
||||
node: Optional[NodeMetadata]
|
||||
|
||||
|
||||
def build_invokeai_metadata_pnginfo(
|
||||
metadata: InvokeAIMetadata | None,
|
||||
) -> PngImagePlugin.PngInfo:
|
||||
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
|
||||
return pnginfo
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||||
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||||
pass
|
||||
"""Handles building metadata for nodes, images, and outputs."""
|
||||
|
||||
@abstractmethod
|
||||
def build_metadata(
|
||||
self, session_id: str, node: BaseModel
|
||||
) -> InvokeAIMetadata | None:
|
||||
"""Builds an InvokeAIMetadata object"""
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""Builds an ImageMetadata object for a node."""
|
||||
pass
|
||||
|
||||
|
||||
class PngMetadataService(MetadataServiceBase):
|
||||
"""Handles loading and building metadata for images."""
|
||||
class CoreMetadataService(MetadataServiceBase):
|
||||
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||
"""The ancestor types that contain the core metadata"""
|
||||
|
||||
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||||
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Loads a specific info entry from a PIL Image."""
|
||||
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
||||
"""The core metadata parameters in the ancestor types"""
|
||||
|
||||
try:
|
||||
info = image.info.get("invokeai")
|
||||
_NOISE_FIELDS = ["seed", "width", "height"]
|
||||
"""The core metadata parameters in the noise node"""
|
||||
|
||||
if type(info) is not str:
|
||||
return None
|
||||
|
||||
loaded_metadata = json.loads(info)
|
||||
|
||||
if type(loaded_metadata) is not dict:
|
||||
return None
|
||||
|
||||
if len(loaded_metadata.items()) == 0:
|
||||
return None
|
||||
|
||||
return loaded_metadata
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Retrieves an image's metadata as a dict"""
|
||||
loaded_metadata = self._load_metadata(image)
|
||||
|
||||
return loaded_metadata
|
||||
|
||||
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||||
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||||
def create_image_metadata(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
metadata = self._build_metadata_from_graph(session, node_id)
|
||||
|
||||
return metadata
|
||||
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
|
||||
"""
|
||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||
|
||||
Parameters:
|
||||
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
||||
have the same data as the execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
||||
"""
|
||||
|
||||
# Retrieve the node from the graph
|
||||
node = G.nodes[node_id]
|
||||
|
||||
# If the node type is one of the core metadata node types, return its id
|
||||
if node.get("type") in self._ANCESTOR_TYPES:
|
||||
return node.get("id")
|
||||
|
||||
# Else, look for the ancestor in the predecessor nodes
|
||||
for predecessor in G.predecessors(node_id):
|
||||
result = self._find_nearest_ancestor(G, predecessor)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# If there are no valid ancestors, return None
|
||||
return None
|
||||
|
||||
def _get_additional_metadata(
|
||||
self, graph: Graph, node_id: str
|
||||
) -> Union[dict[str, Any], None]:
|
||||
"""
|
||||
Returns additional metadata for a given node.
|
||||
|
||||
Parameters:
|
||||
graph (Graph): The execution graph.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: A dictionary of additional metadata.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
|
||||
# Iterate over all edges in the graph
|
||||
for edge in graph.edges:
|
||||
dest_node_id = edge.destination.node_id
|
||||
dest_field = edge.destination.field
|
||||
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
||||
|
||||
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||
if dest_node_id == node_id:
|
||||
# Prompt
|
||||
if dest_field == "positive_conditioning":
|
||||
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||
# Negative prompt
|
||||
if dest_field == "negative_conditioning":
|
||||
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||
# Seed, width and height
|
||||
if dest_field == "noise":
|
||||
for field in self._NOISE_FIELDS:
|
||||
metadata[field] = source_node_dict.get(field)
|
||||
return metadata
|
||||
|
||||
def _build_metadata_from_graph(
|
||||
self, session: GraphExecutionState, node_id: str
|
||||
) -> ImageMetadata:
|
||||
"""
|
||||
Builds an ImageMetadata object for a node.
|
||||
|
||||
Parameters:
|
||||
session (GraphExecutionState): The session.
|
||||
node_id (str): The ID of the node.
|
||||
|
||||
Returns:
|
||||
ImageMetadata: The metadata for the node.
|
||||
"""
|
||||
|
||||
# We need to do all the traversal on the execution graph
|
||||
graph = session.execution_graph
|
||||
|
||||
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
||||
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||
|
||||
# If no ancestor was found, return an empty ImageMetadata object
|
||||
if ancestor_id is None:
|
||||
return ImageMetadata()
|
||||
|
||||
ancestor_node = graph.get_node(ancestor_id)
|
||||
|
||||
# Grab all the core metadata from the ancestor node
|
||||
ancestor_metadata = {
|
||||
param: val
|
||||
for param, val in ancestor_node.dict().items()
|
||||
if param in self._ANCESTOR_PARAMS
|
||||
}
|
||||
|
||||
# Get this image's prompts and noise parameters
|
||||
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||
|
||||
# If additional metadata was found, add it to the main metadata
|
||||
if addl_metadata is not None:
|
||||
ancestor_metadata.update(addl_metadata)
|
||||
|
||||
return ImageMetadata(**ancestor_metadata)
|
||||
|
||||
@@ -2,26 +2,25 @@ import os
|
||||
import sys
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from .config import InvokeAISettings
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: Args) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
)
|
||||
def get_model_manager(config: InvokeAISettings, logger: types.ModuleType) -> ModelManager:
|
||||
model_config = config.model_conf_path
|
||||
if not model_config.exists():
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {model_config} could not be found."), logger
|
||||
)
|
||||
|
||||
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{config.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
@@ -31,20 +30,7 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
import diffusers
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config.embedding_path)
|
||||
)
|
||||
else:
|
||||
embedding_path = config.embedding_path
|
||||
else:
|
||||
embedding_path = None
|
||||
embedding_path = config.embedding_path
|
||||
|
||||
# migrate legacy models
|
||||
ModelManager.migrate_models()
|
||||
@@ -57,37 +43,36 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
else choose_precision(device)
|
||||
|
||||
model_manager = ModelManager(
|
||||
OmegaConf.load(config.conf),
|
||||
OmegaConf.load(config.model_conf_path),
|
||||
precision=precision,
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = Path(embedding_path),
|
||||
embedding_path = embedding_path,
|
||||
logger = logger,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e)
|
||||
report_model_error(config, e, logger)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
logger.error(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := config.autoconvert:
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
if config.autoconvert_path:
|
||||
model_manager.heuristic_import(
|
||||
config.autoconvert_path,
|
||||
)
|
||||
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print(
|
||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
print(
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
@@ -96,13 +81,12 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
if response.startswith(("n", "N")):
|
||||
return
|
||||
|
||||
print("invokeai-configure is launching....\n")
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
previous_config = sys.argv
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
|
||||
148
invokeai/app/services/models/image_record.py
Normal file
148
invokeai/app/services/models/image_record.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
"""Deserialized image record."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
"""The category of the image."""
|
||||
width: int = Field(description="The width of the image in px.")
|
||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||
height: int = Field(description="The height of the image in px.")
|
||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(
|
||||
description="The updated timestamp of the image."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the image."
|
||||
)
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
session_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The session ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The session ID that generated this image, if it is a generated image."""
|
||||
node_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
metadata: Optional[ImageMetadata] = Field(
|
||||
default=None,
|
||||
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||
)
|
||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
"""A set of changes to apply to an image record.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
description="The image's new category."
|
||||
)
|
||||
"""The image's new category."""
|
||||
session_id: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
"""The URL of the image's thumbnail."""
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
"""Deserializes an image record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
width = image_dict.get("width", 0)
|
||||
height = image_dict.get("height", 0)
|
||||
session_id = image_dict.get("session_id", None)
|
||||
node_id = image_dict.get("node_id", None)
|
||||
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
if raw_metadata is not None:
|
||||
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
session_id=session_id,
|
||||
node_id=node_id,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
@@ -1,17 +1,22 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
@@ -20,7 +25,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
)
|
||||
self.__invoker_thread.daemon = (
|
||||
True # TODO: probably better to just not use threads?
|
||||
True # TODO: make async and do not use threads
|
||||
)
|
||||
self.__invoker_thread.start()
|
||||
|
||||
@@ -29,9 +34,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
while not stop_event.is_set():
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
logger.debug("Exception while getting from queue: %s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
graph_execution_state = (
|
||||
@@ -110,7 +122,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
@@ -120,11 +132,22 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
except Exception as e:
|
||||
logger.error("Error while invoking: %s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=traceback.format_exc()
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
... # Log something?
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
finally:
|
||||
self.__threadLimit.release()
|
||||
|
||||
30
invokeai/app/services/resource_name.py
Normal file
30
invokeai/app/services/resource_name.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
import uuid
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
|
||||
IMAGE = "image"
|
||||
LATENT = "latent"
|
||||
|
||||
|
||||
class NameServiceBase(ABC):
|
||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
@abstractmethod
|
||||
def create_image_name(self) -> str:
|
||||
"""Creates a name for an image."""
|
||||
pass
|
||||
|
||||
|
||||
class SimpleNameService(NameServiceBase):
|
||||
"""Creates image names from UUIDs."""
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
||||
@@ -1,6 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
@@ -10,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args):
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
@@ -20,20 +21,22 @@ class RestorationServices:
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
@@ -58,15 +61,15 @@ class RestorationServices:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@@ -80,7 +83,7 @@ class RestorationServices:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
print(">> Face Restoration is disabled.")
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@@ -93,10 +96,10 @@ class RestorationServices:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
|
||||
@@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
|
||||
34
invokeai/app/services/urls.py
Normal file
34
invokeai/app/services/urls.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
class UrlServiceBase(ABC):
|
||||
"""Responsible for building URLs for resources."""
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return (
|
||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||
)
|
||||
|
||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||
15
invokeai/app/util/metaenum.py
Normal file
15
invokeai/app/util/metaenum.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import EnumMeta
|
||||
|
||||
|
||||
class MetaEnum(EnumMeta):
|
||||
"""Metaclass to support additional features in Enums.
|
||||
|
||||
- `in` operator support: `'value' in MyEnum -> bool`
|
||||
"""
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
||||
cls(item)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
@@ -1,5 +1,21 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
def get_iso_timestamp() -> str:
|
||||
return datetime.datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
return np.random.randint(0, SEED_MAX)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generate import Generate
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
@@ -12,5 +11,3 @@ from .generator import (
|
||||
)
|
||||
from .model_management import ModelManager, SDModelComponent
|
||||
from .safety_checker import SafetyChecker
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -25,11 +25,13 @@ from typing import Callable, List, Iterator, Optional, Type
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@@ -70,19 +72,6 @@ class InvokeAIGeneratorOutput:
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
@@ -179,14 +168,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -230,10 +225,10 @@ class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
@@ -359,6 +354,7 @@ class Generator:
|
||||
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||
with scope(self.model.device.type):
|
||||
@@ -376,7 +372,7 @@ class Generator:
|
||||
try:
|
||||
x_T = self.get_noise(width, height)
|
||||
except:
|
||||
print("** An error occurred while getting initial noise **")
|
||||
logger.error("An error occurred while getting initial noise")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
@@ -611,7 +607,7 @@ class Generator:
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or "."
|
||||
if not os.path.exists(dirname):
|
||||
print(f"** creating directory {dirname}")
|
||||
logger.info(f"creating directory {dirname}")
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath, "PNG")
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -72,22 +73,22 @@ class Embiggen(Generator):
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print(
|
||||
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print(
|
||||
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
@@ -97,8 +98,8 @@ class Embiggen(Generator):
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(
|
||||
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
@@ -121,8 +122,8 @@ class Embiggen(Generator):
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
@@ -312,10 +313,10 @@ class Embiggen(Generator):
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
else:
|
||||
print(
|
||||
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
logger.info(
|
||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
)
|
||||
|
||||
emb_tile_store = []
|
||||
@@ -361,11 +362,11 @@ class Embiggen(Generator):
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
print(
|
||||
logger.debug(
|
||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||
)
|
||||
else:
|
||||
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||
@@ -547,8 +548,8 @@ class Embiggen(Generator):
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
print(
|
||||
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
logger.error(
|
||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
)
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
|
||||
@@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -59,7 +60,7 @@ class Inpaint(Img2Img):
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
@@ -75,18 +76,18 @@ class Inpaint(Img2Img):
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: int = None
|
||||
) -> Image:
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a, *tile_size).copy()
|
||||
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
@@ -127,7 +128,9 @@ class Inpaint(Img2Img):
|
||||
|
||||
return si
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
def mask_edge(
|
||||
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||
) -> Image.Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
@@ -193,7 +196,7 @@ class Inpaint(Img2Img):
|
||||
|
||||
seam_noise = self.get_noise(im.width, im.height)
|
||||
|
||||
result = make_image(seam_noise, seed)
|
||||
result = make_image(seam_noise, seed=None)
|
||||
|
||||
return result
|
||||
|
||||
@@ -206,15 +209,15 @@ class Inpaint(Img2Img):
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False,
|
||||
@@ -222,7 +225,7 @@ class Inpaint(Img2Img):
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -239,7 +242,7 @@ class Inpaint(Img2Img):
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
if isinstance(init_image, Image.Image):
|
||||
self.pil_image = init_image.copy()
|
||||
|
||||
# Do infill
|
||||
@@ -250,8 +253,8 @@ class Inpaint(Img2Img):
|
||||
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||
)
|
||||
elif infill_method == "solid":
|
||||
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
|
||||
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non-supported infill type {infill_method}", infill_method
|
||||
@@ -269,7 +272,7 @@ class Inpaint(Img2Img):
|
||||
# Create init tensor
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
if isinstance(mask_image, Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(
|
||||
mask_image,
|
||||
|
||||
@@ -14,6 +14,8 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -77,8 +79,8 @@ class Txt2Img2Img(Generator):
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
logger.info(
|
||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.globals defines a small number of global variables that would
|
||||
otherwise have to be passed through long and complex call chains.
|
||||
|
||||
It defines a Namespace object named "Globals" that contains
|
||||
the attributes:
|
||||
|
||||
- root - the root directory under which "models" and "outputs" can be found
|
||||
- initfile - path to the initialization file
|
||||
- try_patchmatch - option to globally disable loading of 'patchmatch' module
|
||||
- always_use_cpu - force use of CPU even if GPU is available
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
Globals = Namespace()
|
||||
|
||||
# Where to look for the initialization file and other key components
|
||||
Globals.initfile = "invokeai.init"
|
||||
Globals.models_file = "models.yaml"
|
||||
Globals.models_dir = "models"
|
||||
Globals.config_dir = "configs"
|
||||
Globals.autoscan_dir = "weights"
|
||||
Globals.converted_ckpts_dir = "converted_ckpts"
|
||||
|
||||
# Set the default root directory. This can be overwritten by explicitly
|
||||
# passing the `--root <directory>` argument on the command line.
|
||||
# logic is:
|
||||
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
|
||||
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
|
||||
# 3) use ~/invokeai
|
||||
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
|
||||
):
|
||||
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
|
||||
else:
|
||||
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
|
||||
|
||||
# Try loading patchmatch
|
||||
Globals.try_patchmatch = True
|
||||
|
||||
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
|
||||
Globals.always_use_cpu = False
|
||||
|
||||
# Whether the internet is reachable for dynamic downloads
|
||||
# The CLI will test connectivity at startup time.
|
||||
Globals.internet_available = True
|
||||
|
||||
# Whether to disable xformers
|
||||
Globals.disable_xformers = False
|
||||
|
||||
# Low-memory tradeoff for guidance calculations.
|
||||
Globals.sequential_guidance = False
|
||||
|
||||
# whether we are forcing full precision
|
||||
Globals.full_precision = False
|
||||
|
||||
# whether we should convert ckpt files into diffusers models on the fly
|
||||
Globals.ckpt_convert = True
|
||||
|
||||
# logging tokenization everywhere
|
||||
Globals.log_tokenization = False
|
||||
|
||||
|
||||
def global_config_file() -> Path:
|
||||
return Path(Globals.root, Globals.config_dir, Globals.models_file)
|
||||
|
||||
|
||||
def global_config_dir() -> Path:
|
||||
return Path(Globals.root, Globals.config_dir)
|
||||
|
||||
|
||||
def global_models_dir() -> Path:
|
||||
return Path(Globals.root, Globals.models_dir)
|
||||
|
||||
|
||||
def global_autoscan_dir() -> Path:
|
||||
return Path(Globals.root, Globals.autoscan_dir)
|
||||
|
||||
|
||||
def global_converted_ckpts_dir() -> Path:
|
||||
return Path(global_models_dir(), Globals.converted_ckpts_dir)
|
||||
|
||||
|
||||
def global_set_root(root_dir: Union[str, Path]):
|
||||
Globals.root = root_dir
|
||||
|
||||
|
||||
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
|
||||
"""
|
||||
Returns Path to the model cache directory. If a subdirectory
|
||||
is provided, it will be appended to the end of the path, allowing
|
||||
for Hugging Face-style conventions. Currently, Hugging Face has
|
||||
moved all models into the "hub" subfolder, so for any pretrained
|
||||
HF model, use:
|
||||
global_cache_dir('hub')
|
||||
|
||||
The legacy location for transformers used to be global_cache_dir('transformers')
|
||||
and global_cache_dir('diffusers') for diffusers.
|
||||
"""
|
||||
home: str = os.getenv("HF_HOME")
|
||||
|
||||
if home is None:
|
||||
home = os.getenv("XDG_CACHE_HOME")
|
||||
|
||||
if home is not None:
|
||||
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
|
||||
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
|
||||
home += os.sep + "huggingface"
|
||||
|
||||
if home is not None:
|
||||
return Path(home, subdir)
|
||||
else:
|
||||
return Path(Globals.root, "models", subdir)
|
||||
@@ -5,9 +5,9 @@ wraps the actual patchmatch object. It respects the global
|
||||
be suppressed or deferred
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
@@ -24,16 +24,16 @@ class PatchMatch:
|
||||
def _load_patch_match(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
if Globals.try_patchmatch:
|
||||
if config.try_patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
print(">> Patchmatch initialized")
|
||||
logger.info("Patchmatch initialized")
|
||||
else:
|
||||
print(">> Patchmatch not loaded (nonfatal)")
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
self.patch_match = pm
|
||||
else:
|
||||
print(">> Patchmatch loading disabled")
|
||||
logger.info("Patchmatch loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -30,14 +30,14 @@ work fine.
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision import transforms
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
CLIPSEG_SIZE = 352
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||
@@ -83,15 +83,15 @@ class Txt2Mask(object):
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", refined=False):
|
||||
print(">> Initializing clipseg model for text to mask inference")
|
||||
logger.info("Initializing clipseg model for text to mask inference")
|
||||
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -101,18 +101,6 @@ class Txt2Mask(object):
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
pixels indicate where the object is inferred to be.
|
||||
"""
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
transforms.Resize(
|
||||
(CLIPSEG_SIZE, CLIPSEG_SIZE)
|
||||
), # must be multiple of 64...
|
||||
]
|
||||
)
|
||||
|
||||
if type(image) is str:
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
|
||||
@@ -12,17 +12,17 @@ print("Loading Python libraries...\n",file=sys.stderr)
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import HfFolder
|
||||
@@ -35,57 +35,53 @@ from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ...frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from ...frontend.install.widgets import (
|
||||
from invokeai.app.services.config import (
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
CyclingForm,
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
)
|
||||
from ..args import PRECISION_CHOICES, Args
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
|
||||
from .model_install_backend import (
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import (
|
||||
default_dataset,
|
||||
download_from_hf,
|
||||
hf_download_with_resume,
|
||||
recommended_datasets,
|
||||
UserSelections,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
Default_config_file = Path(global_config_dir()) / "models.yaml"
|
||||
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
Datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
# minimum size for the UI
|
||||
MIN_COLS = 135
|
||||
MIN_LINES = 45
|
||||
PRECISION_CHOICES = ['auto','float16','float32','autocast']
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
# Place frequently-used startup commands here, one or more per line.
|
||||
# Examples:
|
||||
# --outdir=D:\data\images
|
||||
# --no-nsfw_checker
|
||||
# --web --host=0.0.0.0
|
||||
# --steps=20
|
||||
# -Ak_euler_a -C10.0
|
||||
"""
|
||||
|
||||
logger=None
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@@ -96,14 +92,13 @@ If you installed manually from source or with 'pip install': activate the virtua
|
||||
then run one of the following commands to start InvokeAI.
|
||||
|
||||
Web UI:
|
||||
invokeai --web # (connect to http://localhost:9090)
|
||||
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
|
||||
invokeai-web
|
||||
|
||||
Command-line interface:
|
||||
Command-line client:
|
||||
invokeai
|
||||
|
||||
If you installed using an installation script, run:
|
||||
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
Add the '--help' argument to see all of the command-line switches available for use.
|
||||
"""
|
||||
@@ -215,16 +210,11 @@ def download_realesrgan():
|
||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
wdn_model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
|
||||
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, wdn_model_dest, "RealESRGANwdn")
|
||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
||||
|
||||
|
||||
def download_gfpgan():
|
||||
@@ -243,8 +233,8 @@ def download_gfpgan():
|
||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
||||
],
|
||||
):
|
||||
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
|
||||
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
|
||||
model_url, model_dest = model[0], config.root_path / model[1]
|
||||
download_with_progress_bar(model_url, str(model_dest), "GFPGAN weights")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -253,8 +243,8 @@ def download_codeformer():
|
||||
model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
)
|
||||
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
|
||||
download_with_progress_bar(model_url, model_dest, "CodeFormer")
|
||||
model_dest = config.root_path / "models/codeformer/codeformer.pth"
|
||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -295,7 +285,7 @@ def download_vaes():
|
||||
# first the diffusers version
|
||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||
args = dict(
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
cache_dir=config.cache_dir,
|
||||
)
|
||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
||||
raise Exception(f"download of {repo_id} failed")
|
||||
@@ -306,7 +296,7 @@ def download_vaes():
|
||||
if not hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_name=model_name,
|
||||
model_dir=str(Globals.root / Model_dir / Weights_dir),
|
||||
model_dir=str(config.root_path / Model_dir / Weights_dir),
|
||||
):
|
||||
raise Exception(f"download of {model_name} failed")
|
||||
except Exception as e:
|
||||
@@ -321,25 +311,24 @@ def get_root(root: str = None) -> str:
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
return str(config.root_path)
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(npyscreen.FormMultiPage):
|
||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts = self.parentApp.invokeai_opts
|
||||
first_time = not (Globals.root / Globals.initfile).exists()
|
||||
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
for i in [
|
||||
"Configure startup settings. You can come back and change these later.",
|
||||
"Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.",
|
||||
"Use cursor arrows to make a checkbox selection, and space to toggle.",
|
||||
]:
|
||||
label = """Configure startup settings. You can come back and change these later.
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@@ -366,7 +355,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
value=old_opts.outdir or str(default_output_dir()),
|
||||
value=str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
@@ -381,22 +370,21 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.safety_checker = self.add_widget_intelligent(
|
||||
self.nsfw_checker = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="NSFW checker",
|
||||
value=old_opts.safety_checker,
|
||||
value=old_opts.nsfw_checker,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
for i in [
|
||||
"If you have an account at HuggingFace you may paste your access token here",
|
||||
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
|
||||
"See https://huggingface.co/settings/tokens",
|
||||
]:
|
||||
label = """If you have an account at HuggingFace you may optionally paste your access token here
|
||||
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
|
||||
"""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
value=line,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
@@ -435,17 +423,10 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.xformers = self.add_widget_intelligent(
|
||||
self.xformers_enabled = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Enable xformers support if available",
|
||||
value=old_opts.xformers,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.ckpt_convert = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Load legacy checkpoint models into memory as diffusers models",
|
||||
value=old_opts.ckpt_convert,
|
||||
value=old_opts.xformers_enabled,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
@@ -480,19 +461,41 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Directory containing embedding/textual inversion files:",
|
||||
value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.embedding_path = self.add_widget_intelligent(
|
||||
self.embedding_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
name=" Textual Inversion Embeddings:",
|
||||
value=str(default_embedding_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lora_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" LoRA and LyCORIS:",
|
||||
value=str(default_lora_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.controlnet_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" ControlNets:",
|
||||
value=str(default_controlnet_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -505,11 +508,11 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
for i in [
|
||||
"BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ",
|
||||
"AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT",
|
||||
"https://huggingface.co/spaces/CompVis/stable-diffusion-license",
|
||||
]:
|
||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSE LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
||||
"""
|
||||
for i in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
@@ -548,7 +551,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
|
||||
def validate_field_values(self, opt: Namespace) -> bool:
|
||||
bad_fields = []
|
||||
if not opt.license_acceptance:
|
||||
@@ -559,9 +562,9 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
)
|
||||
if not Path(opt.embedding_path).parent.exists():
|
||||
if not Path(opt.embedding_dir).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
|
||||
)
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:\n"
|
||||
@@ -576,20 +579,24 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"outdir",
|
||||
"safety_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers",
|
||||
"always_use_cpu",
|
||||
"embedding_path",
|
||||
"ckpt_convert",
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
"embedding_dir",
|
||||
"lora_dir",
|
||||
"controlnet_dir",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
# widget library workaround to make max_loaded_models an int rather than a float
|
||||
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||
|
||||
return new_opts
|
||||
|
||||
@@ -608,6 +615,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
"MAIN",
|
||||
editOptsForm,
|
||||
name="InvokeAI Startup Options",
|
||||
cycle_widgets=True,
|
||||
)
|
||||
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
|
||||
self.model_select = self.addForm(
|
||||
@@ -615,6 +623,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=True,
|
||||
cycle_widgets=True,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
@@ -628,18 +637,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = Args().parse_args([])
|
||||
outdir = Path(opts.outdir)
|
||||
if not outdir.is_absolute():
|
||||
opts.outdir = str(Globals.root / opts.outdir)
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
if not init_file.exists():
|
||||
opts.safety_checker = True
|
||||
opts.nsfw_checker = True
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> Namespace:
|
||||
return Namespace(
|
||||
starter_models=default_dataset()
|
||||
def default_user_selections(program_opts: Namespace) -> UserSelections:
|
||||
return UserSelections(
|
||||
install_models=default_dataset()
|
||||
if program_opts.default_only
|
||||
else recommended_datasets()
|
||||
if program_opts.yes_to_all
|
||||
@@ -647,26 +652,27 @@ def default_user_selections(program_opts: Namespace) -> Namespace:
|
||||
purge_deleted_models=False,
|
||||
scan_directory=None,
|
||||
autoscan_on_startup=None,
|
||||
import_model_paths=None,
|
||||
convert_to_diffusers=None,
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: str, yes_to_all: bool = False):
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||
|
||||
for name in (
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
"models",
|
||||
"configs",
|
||||
"embeddings",
|
||||
"databases",
|
||||
"loras",
|
||||
"controlnets",
|
||||
"text-inversion-output",
|
||||
"text-inversion-training-data",
|
||||
):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
|
||||
configs_src = Path(configs.__path__[0])
|
||||
configs_dest = Path(root) / "configs"
|
||||
configs_dest = root / "configs"
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
@@ -677,8 +683,17 @@ def run_console_ui(
|
||||
) -> (Namespace, Namespace):
|
||||
# parse_args() will read from init file if present
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES)
|
||||
# The third argument is needed in the Windows 11 environment to
|
||||
# launch a console window running this program.
|
||||
set_min_terminal_size(MIN_COLS, MIN_LINES,'invokeai-configure')
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
@@ -690,70 +705,66 @@ def run_console_ui(
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
"""
|
||||
Update the invokeai.init file with values from opts Namespace
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
"""
|
||||
# touch file if it doesn't exist
|
||||
if not init_file.exists():
|
||||
with open(init_file, "w") as f:
|
||||
f.write(INIT_FILE_PREAMBLE)
|
||||
|
||||
# We want to write in the changed arguments without clobbering
|
||||
# any other initialization values the user has entered. There is
|
||||
# no good way to do this because of the one-way nature of
|
||||
# argparse: i.e. --outdir could be --outdir, --out, or -o
|
||||
# initfile needs to be replaced with a fully structured format
|
||||
# such as yaml; this is a hack that will work much of the time
|
||||
args_to_skip = re.compile(
|
||||
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
|
||||
)
|
||||
# fix windows paths
|
||||
opts.outdir = opts.outdir.replace("\\", "/")
|
||||
opts.embedding_path = opts.embedding_path.replace("\\", "/")
|
||||
new_file = f"{init_file}.new"
|
||||
try:
|
||||
lines = [x.strip() for x in open(init_file, "r").readlines()]
|
||||
with open(new_file, "w") as out_file:
|
||||
for line in lines:
|
||||
if len(line) > 0 and not args_to_skip.match(line):
|
||||
out_file.write(line + "\n")
|
||||
out_file.write(
|
||||
f"""
|
||||
--outdir={opts.outdir}
|
||||
--embedding_path={opts.embedding_path}
|
||||
--precision={opts.precision}
|
||||
--max_loaded_models={int(opts.max_loaded_models)}
|
||||
--{'no-' if not opts.safety_checker else ''}nsfw_checker
|
||||
--{'no-' if not opts.xformers else ''}xformers
|
||||
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
|
||||
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
|
||||
{'--always_use_cpu' if opts.always_use_cpu else ''}
|
||||
"""
|
||||
)
|
||||
except OSError as e:
|
||||
print(f"** An error occurred while writing the init file: {str(e)}")
|
||||
|
||||
os.replace(new_file, init_file)
|
||||
|
||||
if opts.hf_token:
|
||||
HfLogin(opts.hf_token)
|
||||
# this will load current settings
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key,value in opts.__dict__.items():
|
||||
if hasattr(new_config,key):
|
||||
setattr(new_config,key,value)
|
||||
|
||||
with open(init_file,'w', encoding='utf-8') as file:
|
||||
file.write(new_config.to_yaml())
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return Globals.root / "outputs"
|
||||
|
||||
return config.root_path / "outputs"
|
||||
|
||||
# -------------------------------------
|
||||
def default_embedding_dir() -> Path:
|
||||
return Globals.root / "embeddings"
|
||||
return config.root_path / "embeddings"
|
||||
|
||||
# -------------------------------------
|
||||
def default_lora_dir() -> Path:
|
||||
return config.root_path / "loras"
|
||||
|
||||
# -------------------------------------
|
||||
def default_controlnet_dir() -> Path:
|
||||
return config.root_path / "controlnets"
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
opt = default_startup_options(initfile)
|
||||
opt.hf_token = HfFolder.get_token()
|
||||
write_opts(opt, initfile)
|
||||
|
||||
# -------------------------------------
|
||||
# Here we bring in
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format:Path):
|
||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
for attr in fields:
|
||||
if hasattr(old,attr):
|
||||
setattr(new,attr,getattr(old,attr))
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
new.nsfw_checker = old.safety_checker
|
||||
new.xformers_enabled = old.xformers
|
||||
new.conf_path = old.conf
|
||||
new.embedding_dir = old.embedding_path
|
||||
|
||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||
outfile.write(new.to_yaml())
|
||||
|
||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
@@ -809,8 +820,13 @@ def main():
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
# setting a global here
|
||||
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
|
||||
invoke_args = []
|
||||
if opt.root:
|
||||
invoke_args.extend(['--root',opt.root])
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
errors = set()
|
||||
|
||||
@@ -818,19 +834,26 @@ def main():
|
||||
models_to_download = default_user_selections(opt)
|
||||
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
init_file = Path(Globals.root, Globals.initfile)
|
||||
if not init_file.exists() or not global_config_file().exists():
|
||||
initialize_rootdir(Globals.root, opt.yes_to_all)
|
||||
old_init_file = config.root_path / 'invokeai.init'
|
||||
new_init_file = config.root_path / 'invokeai.yaml'
|
||||
if old_init_file.exists() and not new_init_file.exists():
|
||||
print('** Migrating invokeai.init to invokeai.yaml')
|
||||
migrate_init_file(old_init_file)
|
||||
# Load new init file into config
|
||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, init_file)
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(
|
||||
precision="float32" if opt.full_precision else "float16"
|
||||
)
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
if init_options:
|
||||
write_opts(init_options, init_file)
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
print(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
@@ -840,7 +863,7 @@ def main():
|
||||
if opt.skip_support_models:
|
||||
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
|
||||
else:
|
||||
print("\n** DOWNLOADING SUPPORT MODELS **")
|
||||
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
|
||||
download_bert()
|
||||
download_sd1_clip()
|
||||
download_sd2_clip()
|
||||
@@ -858,6 +881,8 @@ def main():
|
||||
process_and_execute(opt, models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
if not opt.yes_to_all:
|
||||
input('Press any key to continue...')
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
||||
390
invokeai/backend/install/legacy_arg_parsing.py
Normal file
390
invokeai/backend/install/legacy_arg_parsing.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
from argparse import ArgumentParser
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
"auto",
|
||||
"float32",
|
||||
"autocast",
|
||||
"float16",
|
||||
]
|
||||
|
||||
class FileArgumentParser(ArgumentParser):
|
||||
"""
|
||||
Supports reading defaults from an init file.
|
||||
"""
|
||||
def convert_arg_line_to_args(self, arg_line):
|
||||
return shlex.split(arg_line, comments=True)
|
||||
|
||||
|
||||
legacy_parser = FileArgumentParser(
|
||||
description=
|
||||
"""
|
||||
Generate images using Stable Diffusion.
|
||||
Use --web to launch the web interface.
|
||||
Use --from_file to load prompts from a file path or standard input ("-").
|
||||
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
|
||||
Other command-line arguments are defaults that can usually be overridden
|
||||
prompt the command prompt.
|
||||
""",
|
||||
fromfile_prefix_chars='@',
|
||||
)
|
||||
general_group = legacy_parser.add_argument_group('General')
|
||||
model_group = legacy_parser.add_argument_group('Model selection')
|
||||
file_group = legacy_parser.add_argument_group('Input/output')
|
||||
web_server_group = legacy_parser.add_argument_group('Web server')
|
||||
render_group = legacy_parser.add_argument_group('Rendering')
|
||||
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
|
||||
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
|
||||
|
||||
deprecated_group.add_argument('--laion400m')
|
||||
deprecated_group.add_argument('--weights') # deprecated
|
||||
general_group.add_argument(
|
||||
'--version','-V',
|
||||
action='store_true',
|
||||
help='Print InvokeAI version number'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--root_dir',
|
||||
default=None,
|
||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--config',
|
||||
'-c',
|
||||
'-config',
|
||||
dest='conf',
|
||||
default='./configs/models.yaml',
|
||||
help='Path to configuration file for alternate models.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--model',
|
||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--weight_dirs',
|
||||
nargs='+',
|
||||
type=str,
|
||||
help='List of one or more directories that will be auto-scanned for new model weights to import',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0,9),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'-F',
|
||||
'--full_precision',
|
||||
dest='full_precision',
|
||||
action='store_true',
|
||||
help='Deprecated way to set --precision=float32',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--max_loaded_models',
|
||||
dest='max_loaded_models',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--free_gpu_mem',
|
||||
dest='free_gpu_mem',
|
||||
action='store_true',
|
||||
help='Force free gpu memory before final decoding',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--sequential_guidance',
|
||||
dest='sequential_guidance',
|
||||
action='store_true',
|
||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
|
||||
"at the expense of speed",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--xformers',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable/disable xformers support (default enabled if installed)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--always_use_cpu",
|
||||
dest="always_use_cpu",
|
||||
action="store_true",
|
||||
help="Force use of CPU even if GPU is available"
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--precision',
|
||||
dest='precision',
|
||||
type=str,
|
||||
choices=PRECISION_CHOICES,
|
||||
metavar='PRECISION',
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default='auto',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--ckpt_convert',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='ckpt_convert',
|
||||
default=True,
|
||||
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--internet',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='internet_available',
|
||||
default=True,
|
||||
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--nsfw_checker',
|
||||
'--safety_checker',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='safety_checker',
|
||||
default=False,
|
||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoimport',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoconvert',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--patchmatch',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--from_file',
|
||||
dest='infile',
|
||||
type=str,
|
||||
help='If specified, load prompts from this file',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--outdir',
|
||||
'-o',
|
||||
type=str,
|
||||
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
|
||||
default='outputs',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--prompt_as_dir',
|
||||
'-p',
|
||||
action='store_true',
|
||||
help='Place images in subdirectories named after the prompt.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--fnformat',
|
||||
default='{prefix}.{seed}.png',
|
||||
type=str,
|
||||
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-s',
|
||||
'--steps',
|
||||
type=int,
|
||||
default=50,
|
||||
help='Number of steps'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-W',
|
||||
'--width',
|
||||
type=int,
|
||||
help='Image width, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-H',
|
||||
'--height',
|
||||
type=int,
|
||||
help='Image height, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-C',
|
||||
'--cfg_scale',
|
||||
default=7.5,
|
||||
type=float,
|
||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--sampler',
|
||||
'-A',
|
||||
'-m',
|
||||
dest='sampler_name',
|
||||
type=str,
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar='SAMPLER_NAME',
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default='k_lms',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--log_tokenization',
|
||||
'-t',
|
||||
action='store_true',
|
||||
help='shows how the prompt is split into tokens'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-f',
|
||||
'--strength',
|
||||
type=float,
|
||||
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
'--fit',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||
)
|
||||
|
||||
render_group.add_argument(
|
||||
'--grid',
|
||||
'-g',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='generate a grid'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embedding_directory',
|
||||
'--embedding_path',
|
||||
dest='embedding_path',
|
||||
default='embeddings',
|
||||
type=str,
|
||||
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--lora_directory',
|
||||
dest='lora_path',
|
||||
default='loras',
|
||||
type=str,
|
||||
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embeddings',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable embedding directory (default). Use --no-embeddings to disable.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--enable_image_debugging',
|
||||
action='store_true',
|
||||
help='Generates debugging image to display'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--karras_max',
|
||||
type=int,
|
||||
default=None,
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
||||
)
|
||||
# Restoration related args
|
||||
postprocessing_group.add_argument(
|
||||
'--no_restore',
|
||||
dest='restore',
|
||||
action='store_false',
|
||||
help='Disable face restoration with GFPGAN or codeformer',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--no_upscale',
|
||||
dest='esrgan',
|
||||
action='store_false',
|
||||
help='Disable upscaling with ESRGAN',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_bg_tile',
|
||||
type=int,
|
||||
default=400,
|
||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_denoise_str',
|
||||
type=float,
|
||||
default=0.75,
|
||||
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--gfpgan_model_path',
|
||||
type=str,
|
||||
default='./models/gfpgan/GFPGANv1.4.pth',
|
||||
help='Indicates the path to the GFPGAN model',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web',
|
||||
dest='web',
|
||||
action='store_true',
|
||||
help='Start in web server mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web_develop',
|
||||
dest='web_develop',
|
||||
action='store_true',
|
||||
help='Start in web server development mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_verbose",
|
||||
action="store_true",
|
||||
help="Enables verbose logging",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--cors",
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--host',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default='9090',
|
||||
help='Web server: Port to listen on'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--certfile',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--keyfile',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--gui',
|
||||
dest='gui',
|
||||
action='store_true',
|
||||
help='Start InvokeAI GUI',
|
||||
)
|
||||
@@ -6,26 +6,30 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass,field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import List
|
||||
from typing import List, Dict, Callable
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url
|
||||
from huggingface_hub import hf_hub_url, HfFolder
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir
|
||||
from ..model_management import ModelManager
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
@@ -35,6 +39,9 @@ Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
# logger
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
@@ -45,58 +52,98 @@ Config_preamble = """
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class ModelInstallList:
|
||||
'''Class for listing models to be installed/removed'''
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class UserSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
purge_deleted_models: bool=field(default_factory=list)
|
||||
install_cn_models: List[str] = field(default_factory=list)
|
||||
remove_cn_models: List[str] = field(default_factory=list)
|
||||
install_lora_models: List[str] = field(default_factory=list)
|
||||
remove_lora_models: List[str] = field(default_factory=list)
|
||||
install_ti_models: List[str] = field(default_factory=list)
|
||||
remove_ti_models: List[str] = field(default_factory=list)
|
||||
scan_directory: Path = None
|
||||
autoscan_on_startup: bool=False
|
||||
import_model_paths: str=None
|
||||
|
||||
def default_config_file():
|
||||
return Path(global_config_dir()) / "models.yaml"
|
||||
|
||||
return config.model_conf_path
|
||||
|
||||
def sd_configs():
|
||||
return Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
return config.legacy_conf_path
|
||||
|
||||
def initial_models():
|
||||
global Datasets
|
||||
if Datasets:
|
||||
return Datasets
|
||||
return (Datasets := OmegaConf.load(Dataset_path))
|
||||
|
||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
||||
|
||||
def install_requested_models(
|
||||
install_initial_models: List[str] = None,
|
||||
remove_models: List[str] = None,
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
diffusers: ModelInstallList = None,
|
||||
controlnet: ModelInstallList = None,
|
||||
lora: ModelInstallList = None,
|
||||
ti: ModelInstallList = None,
|
||||
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
model_config_file_callback: Callable[[Path],Path] = None
|
||||
):
|
||||
"""
|
||||
Entry point for installing/deleting starter models, or installing external models.
|
||||
"""
|
||||
access_token = HfFolder.get_token()
|
||||
config_file_path = config_file_path or default_config_file()
|
||||
if not config_file_path.exists():
|
||||
open(config_file_path, "w")
|
||||
|
||||
# prevent circular import here
|
||||
from ..model_management import ModelManager
|
||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
||||
if controlnet:
|
||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
||||
|
||||
if remove_models and len(remove_models) > 0:
|
||||
print("== DELETING UNCHECKED STARTER MODELS ==")
|
||||
for model in remove_models:
|
||||
print(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
if lora:
|
||||
model_manager.install_lora_models(lora.install_models, access_token=access_token)
|
||||
model_manager.delete_lora_models(lora.remove_models)
|
||||
|
||||
if install_initial_models and len(install_initial_models) > 0:
|
||||
print("== INSTALLING SELECTED STARTER MODELS ==")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models=install_initial_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
) # FIX: for historical reasons, we don't use model manager here
|
||||
update_config_file(successfully_downloaded, config_file_path)
|
||||
if len(successfully_downloaded) < len(install_initial_models):
|
||||
print("** Some of the model downloads were not successful")
|
||||
if ti:
|
||||
model_manager.install_ti_models(ti.install_models, access_token=access_token)
|
||||
model_manager.delete_ti_models(ti.remove_models)
|
||||
|
||||
if diffusers:
|
||||
# TODO: Replace next three paragraphs with calls into new model manager
|
||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
||||
logger.info("Processing requested deletions")
|
||||
for model in diffusers.remove_models:
|
||||
logger.info(f"{model}...")
|
||||
model_manager.del_model(model, delete_files=purge_deleted)
|
||||
model_manager.commit(config_file_path)
|
||||
|
||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
||||
logger.info("Installing requested models")
|
||||
downloaded_paths = download_weight_datasets(
|
||||
models=diffusers.install_models,
|
||||
access_token=None,
|
||||
precision=precision,
|
||||
)
|
||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
||||
if len(successful) > 0:
|
||||
update_config_file(successful, config_file_path)
|
||||
if len(successful) < len(diffusers.install_models):
|
||||
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
|
||||
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
|
||||
|
||||
# due to above, we have to reload the model manager because conf file
|
||||
# was changed behind its back
|
||||
@@ -107,12 +154,14 @@ def install_requested_models(
|
||||
external_models.append(str(scan_directory))
|
||||
|
||||
if len(external_models) > 0:
|
||||
print("== INSTALLING EXTERNAL MODELS ==")
|
||||
logger.info("INSTALLING EXTERNAL MODELS")
|
||||
for path_url_or_repo in external_models:
|
||||
try:
|
||||
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
commit_to_conf=config_file_path,
|
||||
config_file_callback = model_config_file_callback,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
@@ -120,17 +169,22 @@ def install_requested_models(
|
||||
pass
|
||||
|
||||
if scan_at_startup and scan_directory.is_dir():
|
||||
argument = "--autoconvert"
|
||||
initfile = Path(Globals.root, Globals.initfile)
|
||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||
directory = str(scan_directory).replace("\\", "/")
|
||||
with open(initfile, "r") as input:
|
||||
with open(replacement, "w") as output:
|
||||
while line := input.readline():
|
||||
if not line.startswith(argument):
|
||||
output.writelines([line])
|
||||
output.writelines([f"{argument} {directory}"])
|
||||
os.replace(replacement, initfile)
|
||||
update_autoconvert_dir(scan_directory)
|
||||
else:
|
||||
update_autoconvert_dir(None)
|
||||
|
||||
def update_autoconvert_dir(autodir: Path):
|
||||
'''
|
||||
Update the "autoconvert_dir" option in invokeai.yaml
|
||||
'''
|
||||
invokeai_config_path = config.init_file_path
|
||||
conf = OmegaConf.load(invokeai_config_path)
|
||||
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
|
||||
yaml = OmegaConf.to_yaml(conf)
|
||||
tmpfile = invokeai_config_path.parent / "new_config.tmp"
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(yaml)
|
||||
tmpfile.replace(invokeai_config_path)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
@@ -142,33 +196,21 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> dict:
|
||||
datasets = dict()
|
||||
def recommended_datasets() -> List['str']:
|
||||
datasets = set()
|
||||
for ds in initial_models().keys():
|
||||
if initial_models()[ds].get("recommended", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
datasets.add(ds)
|
||||
return list(datasets)
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
|
||||
datasets = dict()
|
||||
datasets = set()
|
||||
for ds in initial_models().keys():
|
||||
if initial_models()[ds].get("default", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
datasets.add(ds)
|
||||
return list(datasets)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -183,14 +225,14 @@ def all_datasets() -> dict:
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||
print(
|
||||
logger.warning(
|
||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
||||
)
|
||||
print(f"model.ckpt => {new_name}")
|
||||
logger.warning(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
@@ -203,7 +245,7 @@ def download_weight_datasets(
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models:
|
||||
print(f"Downloading {mod}:")
|
||||
logger.info(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
initial_models()[mod], access_token, precision=precision
|
||||
)
|
||||
@@ -224,11 +266,10 @@ def _download_repo_or_file(
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
@@ -239,9 +280,12 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
model_class: object, model_name: str, **kwargs
|
||||
):
|
||||
path = global_cache_dir(cache_subdir)
|
||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||
|
||||
path = config.cache_dir
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
@@ -272,10 +316,10 @@ def _download_diffusion_weights(
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
if 'Revision Not Found' in str(e):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
logger.error(str(e))
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
@@ -283,9 +327,13 @@ def _download_diffusion_weights(
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
) -> Path:
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
@@ -305,20 +353,19 @@ def hf_download_with_resume(
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
@@ -331,7 +378,7 @@ def hf_download_with_resume(
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
@@ -356,8 +403,8 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
logger.warning(
|
||||
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
@@ -374,16 +421,16 @@ def update_config_file(successfully_downloaded: dict, config_file: Path):
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
logger.info("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
logger.info(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@@ -417,7 +464,7 @@ def new_config_file_contents(
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
successfully_downloaded[model], start=config.root_dir
|
||||
)
|
||||
stanza["config"] = os.path.normpath(
|
||||
os.path.join(sd_configs(), mod["config"])
|
||||
@@ -450,14 +497,14 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
if re.match("/VAE/", conf_stanza.get("config")):
|
||||
return
|
||||
|
||||
print(
|
||||
f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
logger.warning(
|
||||
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
||||
)
|
||||
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = Path(Globals.root) / weights
|
||||
weights = config.root_dir / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
logger.error(str(e))
|
||||
@@ -25,7 +25,8 @@ from typing import Union
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
|
||||
@@ -46,6 +47,7 @@ from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -72,7 +74,6 @@ from transformers import (
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
@@ -372,9 +373,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(" | Extracting EMA weights (usually better for inference)")
|
||||
logger.debug("Extracting EMA weights (usually better for inference)")
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
@@ -392,8 +393,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
key
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
logger.debug(
|
||||
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
@@ -841,7 +842,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
|
||||
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
|
||||
)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -896,7 +897,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def convert_paint_by_example_checkpoint(checkpoint):
|
||||
cache_dir = global_cache_dir("hub")
|
||||
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||
config = CLIPVisionConfig.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
@@ -968,7 +969,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
cache_dir = global_cache_dir("hub")
|
||||
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||
)
|
||||
@@ -1091,6 +1092,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
:param vae: A diffusers VAE to load into the pipeline.
|
||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
cache_dir = config.cache_dir
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
@@ -1104,7 +1107,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
|
||||
cache_dir = global_cache_dir("hub")
|
||||
pipeline_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if return_generator_pipeline
|
||||
@@ -1115,7 +1117,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print(" | global_step key not found in model")
|
||||
logger.debug("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
@@ -1128,25 +1130,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if model_type == SDLegacyType.V2_v:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
|
||||
config.legacy_conf_path / "v2-inference-v.yaml"
|
||||
)
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
|
||||
config.legacy_conf_path / "v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
original_config_file = (
|
||||
global_config_dir()
|
||||
/ "stable-diffusion"
|
||||
/ "v1-inpainting-inference.yaml"
|
||||
config.legacy_conf_path / "v1-inpainting-inference.yaml"
|
||||
)
|
||||
|
||||
elif model_type == SDLegacyType.V1:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
|
||||
config.legacy_conf_path / "v1-inference.yaml"
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -1208,6 +1208,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == 'unipc':
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
@@ -1229,15 +1231,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
# If a replacement VAE path was specified, we'll incorporate that into
|
||||
# the checkpoint model and then convert it
|
||||
if vae_path:
|
||||
print(f" | Converting VAE {vae_path}")
|
||||
logger.debug(f"Converting VAE {vae_path}")
|
||||
replace_checkpoint_vae(checkpoint,vae_path)
|
||||
# otherwise we use the original VAE, provided that
|
||||
# an externally loaded diffusers VAE was not passed
|
||||
elif not vae:
|
||||
print(" | Using checkpoint model's original VAE")
|
||||
logger.debug("Using checkpoint model's original VAE")
|
||||
|
||||
if vae:
|
||||
print(" | Using replacement diffusers VAE")
|
||||
logger.debug("Using replacement diffusers VAE")
|
||||
else: # convert the original or replacement VAE
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
@@ -1296,7 +1298,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker",
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||
|
||||
@@ -11,32 +11,33 @@ import gc
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable
|
||||
from typing import Any, Optional, Union, Callable, Dict, List, types
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import transformers
|
||||
import invokeai.backend.util.logging as logger
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
SchedulerMixin,
|
||||
logging as dlogging,
|
||||
)
|
||||
)
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
@@ -48,9 +49,13 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
from ..stable_diffusion import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..install.model_install_backend import (
|
||||
Dataset_path,
|
||||
hf_download_with_resume,
|
||||
)
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = auto()
|
||||
V1_INPAINT = auto()
|
||||
@@ -67,7 +72,7 @@ class SDModelComponent(Enum):
|
||||
scheduler="scheduler"
|
||||
safety_checker="safety_checker"
|
||||
feature_extractor="feature_extractor"
|
||||
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
|
||||
class ModelManager(object):
|
||||
@@ -75,6 +80,8 @@ class ModelManager(object):
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
"""
|
||||
|
||||
logger: types.ModuleType = logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf | Path,
|
||||
@@ -83,6 +90,7 @@ class ModelManager(object):
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path = None,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file or
|
||||
@@ -96,6 +104,7 @@ class ModelManager(object):
|
||||
if not isinstance(config, DictConfig):
|
||||
config = OmegaConf.load(config)
|
||||
self.config = config
|
||||
self.globals = InvokeAIAppConfig.get_config()
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
self.max_loaded_models = max_loaded_models
|
||||
@@ -104,6 +113,7 @@ class ModelManager(object):
|
||||
self.current_model = None
|
||||
self.sequential_offload = sequential_offload
|
||||
self.embedding_path = embedding_path
|
||||
self.logger = logger
|
||||
|
||||
def valid_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
@@ -132,8 +142,8 @@ class ModelManager(object):
|
||||
)
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return self.current_model
|
||||
|
||||
@@ -144,7 +154,7 @@ class ModelManager(object):
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]["model"]
|
||||
print(f">> Retrieving model {model_name} from system RAM cache")
|
||||
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
|
||||
requested_model.ready()
|
||||
width = self.models[model_name]["width"]
|
||||
height = self.models[model_name]["height"]
|
||||
@@ -177,7 +187,7 @@ class ModelManager(object):
|
||||
vae from the model currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.vae)
|
||||
|
||||
|
||||
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTokenizer. If no
|
||||
@@ -185,12 +195,12 @@ class ModelManager(object):
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
||||
|
||||
|
||||
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
||||
name is provided, return the UNet from the model
|
||||
currently in the GPU.
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.unet)
|
||||
|
||||
@@ -217,7 +227,7 @@ class ModelManager(object):
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
||||
|
||||
|
||||
def _get_sub_model(
|
||||
self,
|
||||
model_name: str=None,
|
||||
@@ -287,7 +297,7 @@ class ModelManager(object):
|
||||
"""
|
||||
# if we are converting legacy files automatically, then
|
||||
# there are no legacy ckpts!
|
||||
if Globals.ckpt_convert:
|
||||
if self.globals.ckpt_convert:
|
||||
return False
|
||||
info = self.model_info(model_name)
|
||||
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
||||
@@ -310,7 +320,7 @@ class ModelManager(object):
|
||||
models = {}
|
||||
for name in sorted(self.config, key=str.casefold):
|
||||
stanza = self.config[name]
|
||||
|
||||
|
||||
# don't include VAEs in listing (legacy style)
|
||||
if "config" in stanza and "/VAE/" in stanza["config"]:
|
||||
continue
|
||||
@@ -379,7 +389,7 @@ class ModelManager(object):
|
||||
"""
|
||||
omega = self.config
|
||||
if model_name not in omega:
|
||||
print(f"** Unknown model {model_name}")
|
||||
self.logger.error(f"Unknown model {model_name}")
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
@@ -392,13 +402,13 @@ class ModelManager(object):
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
print(f"** Deleting file {weights}")
|
||||
self.logger.info(f"Deleting file {weights}")
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
print(f"** Deleting directory {path}")
|
||||
self.logger.info(f"Deleting directory {path}")
|
||||
rmtree(path, ignore_errors=True)
|
||||
elif repo_id:
|
||||
print(f"** Deleting the cached model directory for {repo_id}")
|
||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(
|
||||
@@ -439,7 +449,7 @@ class ModelManager(object):
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return
|
||||
@@ -457,7 +467,7 @@ class ModelManager(object):
|
||||
model_format = mconfig.get("format", "ckpt")
|
||||
if model_format == "ckpt":
|
||||
weights = mconfig.weights
|
||||
print(f">> Loading {model_name} from {weights}")
|
||||
self.logger.info(f"Loading {model_name} from {weights}")
|
||||
model, width, height, model_hash = self._load_ckpt_model(
|
||||
model_name, mconfig
|
||||
)
|
||||
@@ -473,13 +483,15 @@ class ModelManager(object):
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||
if self._has_cuda():
|
||||
print(
|
||||
">> Max VRAM used to load the model:",
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
"\n>> Current VRAM usage:"
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
self.logger.info(
|
||||
"Max VRAM used to load the model: "+
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||
)
|
||||
self.logger.info(
|
||||
"Current VRAM usage: "+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
@@ -487,21 +499,21 @@ class ModelManager(object):
|
||||
name_or_path = self.model_name_or_path(mconfig)
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
print(f">> Loading diffusers model from {name_or_path}")
|
||||
self.logger.info(f"Loading diffusers model from {name_or_path}")
|
||||
if using_fp16:
|
||||
print(" | Using faster float16 precision")
|
||||
self.logger.debug("Using faster float16 precision")
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
safety_checker=None, local_files_only=not Globals.internet_available
|
||||
safety_checker=None, local_files_only=not self.globals.internet_available
|
||||
)
|
||||
if "vae" in mconfig and mconfig["vae"] is not None:
|
||||
if vae := self._load_vae(mconfig["vae"]):
|
||||
pipeline_args.update(vae=vae)
|
||||
if not isinstance(name_or_path, Path):
|
||||
pipeline_args.update(cache_dir=global_cache_dir("hub"))
|
||||
pipeline_args.update(cache_dir=self.globals.cache_dir)
|
||||
if using_fp16:
|
||||
pipeline_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
@@ -520,11 +532,12 @@ class ModelManager(object):
|
||||
**fp_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
if str(e).startswith("fp16 is not a valid") or \
|
||||
'Invalid rev id: fp16' in str(e):
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"** An unexpected error occurred while downloading the model: {e})"
|
||||
self.logger.error(
|
||||
f"An unexpected error occurred while downloading the model: {e})"
|
||||
)
|
||||
if pipeline:
|
||||
break
|
||||
@@ -542,7 +555,7 @@ class ModelManager(object):
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
self.logger.debug(f"Default image dimensions = {width} x {height}")
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
@@ -553,29 +566,24 @@ class ModelManager(object):
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root, config)
|
||||
if not os.path.isabs(weights):
|
||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||
root_dir = self.globals.root_dir
|
||||
config = str(root_dir / config)
|
||||
weights = str(root_dir / weights)
|
||||
|
||||
# Convert to diffusers and return a diffusers pipeline
|
||||
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
try:
|
||||
if self.list_models()[self.current_model]["status"] == "active":
|
||||
self.offload_model(self.current_model)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vae_path = None
|
||||
if vae:
|
||||
vae_path = (
|
||||
vae
|
||||
if os.path.isabs(vae)
|
||||
else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
)
|
||||
vae_path = str(root_dir / vae)
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
@@ -607,9 +615,7 @@ class ModelManager(object):
|
||||
)
|
||||
|
||||
if "path" in mconfig and mconfig["path"] is not None:
|
||||
path = Path(mconfig["path"])
|
||||
if not path.is_absolute():
|
||||
path = Path(Globals.root, path).resolve()
|
||||
path = self.globals.root_dir / Path(mconfig["path"])
|
||||
return path
|
||||
elif "repo_id" in mconfig:
|
||||
return mconfig["repo_id"]
|
||||
@@ -624,7 +630,7 @@ class ModelManager(object):
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
print(f">> Offloading {model_name} to CPU")
|
||||
self.logger.info(f"Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.current_model = None
|
||||
@@ -640,30 +646,26 @@ class ModelManager(object):
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
print(f" | Scanning Model: {model_name}")
|
||||
self.logger.debug(f"Scanning Model: {model_name}")
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
|
||||
print(
|
||||
"### WARNING: The model you are trying to load seems to be infected."
|
||||
)
|
||||
print("### For your safety, InvokeAI will not load this model.")
|
||||
print("### Please use checkpoints from trusted sources.")
|
||||
print("### Exiting InvokeAI")
|
||||
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||
self.logger.critical("The model you are trying to load seems to be infected.")
|
||||
self.logger.critical("For your safety, InvokeAI will not load this model.")
|
||||
self.logger.critical("Please use checkpoints from trusted sources.")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(
|
||||
"\n### WARNING: InvokeAI was unable to scan the model you are using."
|
||||
)
|
||||
self.logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||
model_safe_check_fail = ask_user(
|
||||
"Do you want to to continue loading the model?", ["y", "n"]
|
||||
)
|
||||
if model_safe_check_fail.lower() != "y":
|
||||
print("### Exiting InvokeAI")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(" | Model scanned ok")
|
||||
self.logger.debug("Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
@@ -778,28 +780,26 @@ class ModelManager(object):
|
||||
|
||||
"""
|
||||
model_path: Path = None
|
||||
thing = path_url_or_repo # to save typing
|
||||
thing = str(path_url_or_repo) # to save typing
|
||||
|
||||
print(f">> Probing {thing} for import")
|
||||
self.logger.info(f"Probing {thing} for import")
|
||||
|
||||
if thing.startswith(("http:", "https:", "ftp:")):
|
||||
print(f" | {thing} appears to be a URL")
|
||||
if str(thing).startswith(("http:", "https:", "ftp:")):
|
||||
self.logger.info(f"{thing} appears to be a URL")
|
||||
model_path = self._resolve_path(
|
||||
thing, "models/ldm/stable-diffusion-v1"
|
||||
) # _resolve_path does a download if needed
|
||||
|
||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||
print(
|
||||
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
||||
)
|
||||
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||
return
|
||||
else:
|
||||
print(f" | {thing} appears to be a checkpoint file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||
|
||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
@@ -810,34 +810,32 @@ class ModelManager(object):
|
||||
|
||||
elif Path(thing).is_dir():
|
||||
if (Path(thing) / "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers model.")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers model.")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" |{thing} appears to be a directory. Will scan for models to import"
|
||||
)
|
||||
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||
Path(thing).rglob("*.safetensors")
|
||||
):
|
||||
if model_name := self.heuristic_import(
|
||||
str(m), commit_to_conf=commit_to_conf
|
||||
str(m),
|
||||
commit_to_conf=commit_to_conf,
|
||||
config_file_callback=config_file_callback,
|
||||
):
|
||||
print(f" >> {model_name} successfully imported")
|
||||
self.logger.info(f"{model_name} successfully imported")
|
||||
return model_name
|
||||
|
||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
||||
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||
return model_name
|
||||
else:
|
||||
print(
|
||||
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
|
||||
)
|
||||
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||
|
||||
# Model_path is set in the event of a legacy checkpoint file.
|
||||
# If not set, we're all done
|
||||
@@ -845,7 +843,7 @@ class ModelManager(object):
|
||||
return
|
||||
|
||||
if model_path.stem in self.config: # already imported
|
||||
print(" | Already imported. Skipping")
|
||||
self.logger.debug("Already imported. Skipping")
|
||||
return model_path.stem
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
@@ -861,41 +859,30 @@ class ModelManager(object):
|
||||
# look for a like-named .yaml file in same directory
|
||||
if model_path.with_suffix(".yaml").exists():
|
||||
model_config_file = model_path.with_suffix(".yaml")
|
||||
print(f" | Using config file {model_config_file.name}")
|
||||
self.logger.debug(f"Using config file {model_config_file.name}")
|
||||
|
||||
else:
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
if model_type == SDLegacyType.V1:
|
||||
print(" | SD-v1 model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
self.logger.debug("SD-v1 model detected")
|
||||
model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
self.logger.debug("SD-v1 inpainting model detected")
|
||||
model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
print(" | SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
self.logger.debug("SD-v2-v model detected")
|
||||
model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
print(" | SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
self.logger.debug("SD-v2-e model detected")
|
||||
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
|
||||
elif model_type == SDLegacyType.V2:
|
||||
print(
|
||||
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
self.logger.warning(
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined."
|
||||
)
|
||||
return
|
||||
else:
|
||||
print(
|
||||
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
self.logger.warning(
|
||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model."
|
||||
)
|
||||
return
|
||||
|
||||
if not model_config_file and config_file_callback:
|
||||
model_config_file = config_file_callback(model_path)
|
||||
@@ -909,12 +896,10 @@ class ModelManager(object):
|
||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||
print(f" | Using VAE file {vae_path.name}")
|
||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
|
||||
diffuser_path = Path(
|
||||
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
|
||||
)
|
||||
diffuser_path = self.globals.root_dir / "models/converted_ckpts" / model_path.stem
|
||||
model_name = self.convert_and_import(
|
||||
model_path,
|
||||
diffusers_path=diffuser_path,
|
||||
@@ -954,35 +939,36 @@ class ModelManager(object):
|
||||
|
||||
from . import convert_ckpt_to_diffusers
|
||||
|
||||
if diffusers_path.exists():
|
||||
print(
|
||||
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
)
|
||||
return
|
||||
|
||||
model_name = model_name or diffusers_path.name
|
||||
model_description = model_description or f"Converted version of {model_name}"
|
||||
print(f" | Converting {model_name} to diffusers (30-60s)")
|
||||
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model = None
|
||||
if vae:
|
||||
vae_model = self._load_vae(vae)
|
||||
vae_path = None
|
||||
convert_ckpt_to_diffusers(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
extract_ema=True,
|
||||
original_config_file=original_config_file,
|
||||
vae=vae_model,
|
||||
vae_path=vae_path,
|
||||
scan_needed=scan_needed,
|
||||
)
|
||||
print(
|
||||
f" | Success. Converted model is now located at {str(diffusers_path)}"
|
||||
)
|
||||
print(f" | Writing new config file entry for {model_name}")
|
||||
if diffusers_path.exists():
|
||||
self.logger.error(
|
||||
f"The path {str(diffusers_path)} already exists. Installing previously-converted path."
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model = None
|
||||
if vae:
|
||||
vae_model = self._load_vae(vae)
|
||||
vae_path = None
|
||||
convert_ckpt_to_diffusers(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
extract_ema=True,
|
||||
original_config_file=original_config_file,
|
||||
vae=vae_model,
|
||||
vae_path=vae_path,
|
||||
scan_needed=scan_needed,
|
||||
)
|
||||
self.logger.debug(
|
||||
f"Success. Converted model is now located at {str(diffusers_path)}"
|
||||
)
|
||||
self.logger.debug(f"Writing new config file entry for {model_name}")
|
||||
new_config = dict(
|
||||
path=str(diffusers_path),
|
||||
description=model_description,
|
||||
@@ -993,17 +979,18 @@ class ModelManager(object):
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
print(" | Conversion succeeded")
|
||||
self.logger.debug(f"Model {model_name} installed")
|
||||
except Exception as e:
|
||||
print(f"** Conversion failed: {str(e)}")
|
||||
print(
|
||||
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
self.logger.warning(f"Conversion failed: {str(e)}")
|
||||
self.logger.warning(traceback.format_exc())
|
||||
self.logger.warning(
|
||||
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
def search_models(self, search_folder):
|
||||
print(f">> Finding Models In: {search_folder}")
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||
|
||||
@@ -1027,8 +1014,8 @@ class ModelManager(object):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(
|
||||
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
self.logger.info(
|
||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
)
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
@@ -1036,8 +1023,8 @@ class ModelManager(object):
|
||||
|
||||
def print_vram_usage(self) -> None:
|
||||
if self._has_cuda:
|
||||
print(
|
||||
">> Current VRAM usage: ",
|
||||
self.logger.info(
|
||||
"Current VRAM usage:"+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
@@ -1047,9 +1034,7 @@ class ModelManager(object):
|
||||
"""
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
if not os.path.isabs(config_file_path):
|
||||
config_file_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config_file_path)
|
||||
)
|
||||
config_file_path = self.globals.model_conf_path
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(self.preamble())
|
||||
@@ -1081,7 +1066,8 @@ class ModelManager(object):
|
||||
"""
|
||||
# Three transformer models to check: bert, clip and safety checker, and
|
||||
# the diffusers as well
|
||||
models_dir = Path(Globals.root, "models")
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
models_dir = config.root_dir / "models"
|
||||
legacy_locations = [
|
||||
Path(
|
||||
models_dir,
|
||||
@@ -1093,8 +1079,8 @@ class ModelManager(object):
|
||||
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
|
||||
),
|
||||
]
|
||||
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
|
||||
|
||||
legacy_cache_dir = config.cache_dir / "../diffusers"
|
||||
legacy_locations.extend(list(legacy_cache_dir.glob("*")))
|
||||
legacy_layout = False
|
||||
for model in legacy_locations:
|
||||
legacy_layout = legacy_layout or model.exists()
|
||||
@@ -1116,7 +1102,7 @@ class ModelManager(object):
|
||||
|
||||
# transformer files get moved into the hub directory
|
||||
if cls._is_huggingface_hub_directory_present():
|
||||
hub = global_cache_dir("hub")
|
||||
hub = config.cache_dir
|
||||
else:
|
||||
hub = models_dir / "hub"
|
||||
|
||||
@@ -1126,10 +1112,10 @@ class ModelManager(object):
|
||||
dest = hub / model.stem
|
||||
if dest.exists() and not source.exists():
|
||||
continue
|
||||
print(f"** {source} => {dest}")
|
||||
cls.logger.info(f"{source} => {dest}")
|
||||
if source.exists():
|
||||
if dest.is_symlink():
|
||||
print(f"** Found symlink at {dest.name}. Not migrating.")
|
||||
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
|
||||
elif dest.exists():
|
||||
if source.is_dir():
|
||||
rmtree(source)
|
||||
@@ -1146,7 +1132,7 @@ class ModelManager(object):
|
||||
]
|
||||
for d in empty:
|
||||
os.rmdir(d)
|
||||
print("** Migration is done. Continuing...")
|
||||
cls.logger.info("Migration is done. Continuing...")
|
||||
|
||||
def _resolve_path(
|
||||
self, source: Union[str, Path], dest_directory: str
|
||||
@@ -1155,13 +1141,12 @@ class ModelManager(object):
|
||||
if str(source).startswith(("http:", "https:", "ftp:")):
|
||||
dest_directory = Path(dest_directory)
|
||||
if not dest_directory.is_absolute():
|
||||
dest_directory = Globals.root / dest_directory
|
||||
dest_directory = self.globals.root_dir / dest_directory
|
||||
dest_directory.mkdir(parents=True, exist_ok=True)
|
||||
resolved_path = download_with_resume(str(source), dest_directory)
|
||||
else:
|
||||
if not os.path.isabs(source):
|
||||
source = os.path.join(Globals.root, source)
|
||||
resolved_path = Path(source)
|
||||
source = self.globals.root_dir / source
|
||||
resolved_path = source
|
||||
return resolved_path
|
||||
|
||||
def _invalidate_cached_model(self, model_name: str) -> None:
|
||||
@@ -1189,15 +1174,15 @@ class ModelManager(object):
|
||||
|
||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||
if self.embedding_path is not None:
|
||||
print(f">> Loading embeddings from {self.embedding_path}")
|
||||
self.logger.info(f"Loading embeddings from {self.embedding_path}")
|
||||
for root, _, files in os.walk(self.embedding_path):
|
||||
for name in files:
|
||||
ti_path = os.path.join(root, name)
|
||||
model.textual_inversion_manager.load_textual_inversion(
|
||||
ti_path, defer_injecting_tokens=True
|
||||
)
|
||||
print(
|
||||
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
self.logger.info(
|
||||
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
)
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
@@ -1211,7 +1196,7 @@ class ModelManager(object):
|
||||
path = name_or_path
|
||||
else:
|
||||
owner, repo = name_or_path.split("/")
|
||||
path = Path(global_cache_dir("hub") / f"models--{owner}--{repo}")
|
||||
path = self.globals.cache_dir / f"models--{owner}--{repo}"
|
||||
if not path.exists():
|
||||
return None
|
||||
hashpath = path / "checksum.sha256"
|
||||
@@ -1219,7 +1204,7 @@ class ModelManager(object):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(" | Calculating sha256 hash of model files")
|
||||
self.logger.debug("Calculating sha256 hash of model files")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
count = 0
|
||||
@@ -1231,7 +1216,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
@@ -1249,13 +1234,13 @@ class ModelManager(object):
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(" | Calculating sha256 hash of weights file")
|
||||
self.logger.debug("Calculating sha256 hash of weights file")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
@@ -1272,16 +1257,16 @@ class ModelManager(object):
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
vae_args.update(
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
local_files_only=not Globals.internet_available,
|
||||
cache_dir=self.globals.cache_dir,
|
||||
local_files_only=not self.globals.internet_available,
|
||||
)
|
||||
|
||||
print(f" | Loading diffusers VAE from {name_or_path}")
|
||||
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||
if using_fp16:
|
||||
vae_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
fp_args_list = [{}]
|
||||
|
||||
vae = None
|
||||
@@ -1305,13 +1290,13 @@ class ModelManager(object):
|
||||
break
|
||||
|
||||
if not vae and deferred_error:
|
||||
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def _delete_model_from_cache(repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||
@classmethod
|
||||
def _delete_model_from_cache(cls,repo_id):
|
||||
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
|
||||
|
||||
# I'm sure there is a way to do this with comprehensions
|
||||
# but the code quickly became incomprehensible!
|
||||
@@ -1321,19 +1306,202 @@ class ModelManager(object):
|
||||
for revision in repo.revisions:
|
||||
hashes_to_delete.add(revision.commit_hash)
|
||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||
print(
|
||||
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
cls.logger.warning(
|
||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
)
|
||||
strategy.execute()
|
||||
|
||||
@staticmethod
|
||||
def _abs_path(path: str | Path) -> Path:
|
||||
globals = InvokeAIAppConfig.get_config()
|
||||
if path is None or Path(path).is_absolute():
|
||||
return path
|
||||
return Path(Globals.root, path).resolve()
|
||||
return Path(globals.root_dir, path).resolve()
|
||||
|
||||
@staticmethod
|
||||
def _is_huggingface_hub_directory_present() -> bool:
|
||||
return (
|
||||
os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
|
||||
)
|
||||
|
||||
def list_lora_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed lora models; key is either the shortname
|
||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
||||
directory. Value is True if installed'''
|
||||
|
||||
models = OmegaConf.load(Dataset_path).get('lora') or {}
|
||||
installed_models = {x: False for x in models.keys()}
|
||||
|
||||
dir = self.globals.lora_path
|
||||
installed_models = dict()
|
||||
for root, dirs, files in os.walk(dir):
|
||||
for name in files:
|
||||
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
|
||||
continue
|
||||
if name == 'pytorch_lora_weights.bin':
|
||||
name = Path(root,name).parent.stem #Path(root,name).stem
|
||||
else:
|
||||
name = Path(name).stem
|
||||
installed_models.update({name: True})
|
||||
|
||||
return installed_models
|
||||
|
||||
def install_lora_models(self, model_names: list[str], access_token:str=None):
|
||||
'''Download list of LoRA/LyCORIS models'''
|
||||
|
||||
short_names = OmegaConf.load(Dataset_path).get('lora') or {}
|
||||
for name in model_names:
|
||||
name = short_names.get(name) or name
|
||||
|
||||
# HuggingFace style LoRA
|
||||
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
|
||||
self.logger.info(f'Downloading LoRA/LyCORIS model {name}')
|
||||
_,dest_dir = name.split("/")
|
||||
|
||||
hf_download_with_resume(
|
||||
repo_id = name,
|
||||
model_dir = self.globals.lora_path / dest_dir,
|
||||
model_name = 'pytorch_lora_weights.bin',
|
||||
access_token = access_token,
|
||||
)
|
||||
|
||||
elif name.startswith(("http:", "https:", "ftp:")):
|
||||
download_with_resume(name, self.globals.lora_path)
|
||||
|
||||
else:
|
||||
self.logger.error(f"Unknown repo_id or URL: {name}")
|
||||
|
||||
def delete_lora_models(self, model_names: List[str]):
|
||||
'''Remove the list of lora models'''
|
||||
for name in model_names:
|
||||
file_or_directory = self.globals.lora_path / name
|
||||
if file_or_directory.is_dir():
|
||||
self.logger.info(f'Purging LoRA/LyCORIS {name}')
|
||||
shutil.rmtree(str(file_or_directory))
|
||||
else:
|
||||
for path in self.globals.lora_path.glob(f'{name}.*'):
|
||||
self.logger.info(f'Purging LoRA/LyCORIS {name}')
|
||||
path.unlink()
|
||||
|
||||
def list_ti_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed textual models; key is either the shortname
|
||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
||||
directory. Value is True if installed'''
|
||||
|
||||
models = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
|
||||
installed_models = {x: False for x in models.keys()}
|
||||
|
||||
dir = self.globals.embedding_path
|
||||
for root, dirs, files in os.walk(dir):
|
||||
for name in files:
|
||||
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
|
||||
continue
|
||||
if name == 'learned_embeds.bin':
|
||||
name = Path(root,name).parent.stem #Path(root,name).stem
|
||||
else:
|
||||
name = Path(name).stem
|
||||
installed_models.update({name: True})
|
||||
return installed_models
|
||||
|
||||
def install_ti_models(self, model_names: list[str], access_token: str=None):
|
||||
'''Download list of textual inversion embeddings'''
|
||||
|
||||
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
|
||||
for name in model_names:
|
||||
name = short_names.get(name) or name
|
||||
|
||||
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
|
||||
self.logger.info(f'Downloading Textual Inversion embedding {name}')
|
||||
_,dest_dir = name.split("/")
|
||||
hf_download_with_resume(
|
||||
repo_id = name,
|
||||
model_dir = self.globals.embedding_path / dest_dir,
|
||||
model_name = 'learned_embeds.bin',
|
||||
access_token = access_token
|
||||
)
|
||||
elif name.startswith(('http:','https:','ftp:')):
|
||||
download_with_resume(name, self.globals.embedding_path)
|
||||
else:
|
||||
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
|
||||
|
||||
def delete_ti_models(self, model_names: list[str]):
|
||||
'''Remove TI embeddings from disk'''
|
||||
for name in model_names:
|
||||
file_or_directory = self.globals.embedding_path / name
|
||||
if file_or_directory.is_dir():
|
||||
self.logger.info(f'Purging textual inversion embedding {name}')
|
||||
shutil.rmtree(str(file_or_directory))
|
||||
else:
|
||||
for path in self.globals.embedding_path.glob(f'{name}.*'):
|
||||
self.logger.info(f'Purging textual inversion embedding {name}')
|
||||
path.unlink()
|
||||
|
||||
def list_controlnet_models(self)->Dict[str,bool]:
|
||||
'''Return a dict of installed controlnet models; key is repo_id or short name
|
||||
of model (defined in INITIAL_MODELS), and value is True if installed'''
|
||||
|
||||
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
installed_models = {x: False for x in cn_models.keys()}
|
||||
|
||||
cn_dir = self.globals.controlnet_path
|
||||
for root, dirs, files in os.walk(cn_dir):
|
||||
for name in dirs:
|
||||
if Path(root, name, '.download_complete').exists():
|
||||
installed_models.update({name.replace('--','/'): True})
|
||||
return installed_models
|
||||
|
||||
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
|
||||
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
|
||||
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
|
||||
dest_dir = self.globals.controlnet_path
|
||||
dest_dir.mkdir(parents=True,exist_ok=True)
|
||||
|
||||
# The model file may be fp32 or fp16, and may be either a
|
||||
# .bin file or a .safetensors. We try each until we get one,
|
||||
# preferring 'fp16' if using half precision, and preferring
|
||||
# safetensors over over bin.
|
||||
precisions = ['.fp16',''] if self.precision=='float16' else ['']
|
||||
formats = ['.safetensors','.bin']
|
||||
possible_filenames = list()
|
||||
for p in precisions:
|
||||
for f in formats:
|
||||
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
|
||||
|
||||
for directory_name in model_names:
|
||||
repo_id = short_names.get(directory_name) or directory_name
|
||||
safe_name = directory_name.replace('/','--')
|
||||
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
|
||||
hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = 'config.json',
|
||||
access_token = access_token
|
||||
)
|
||||
|
||||
path = None
|
||||
for filename in possible_filenames:
|
||||
suffix = filename.suffix
|
||||
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
|
||||
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
|
||||
path = hf_download_with_resume(
|
||||
repo_id = repo_id,
|
||||
model_dir = dest_dir / safe_name,
|
||||
model_name = str(filename),
|
||||
access_token = access_token,
|
||||
model_dest = Path(dest_dir, safe_name, dest_filename),
|
||||
)
|
||||
if path:
|
||||
(path.parent / '.download_complete').touch()
|
||||
break
|
||||
|
||||
def delete_controlnet_models(self, model_names: List[str]):
|
||||
'''Remove the list of controlnet models'''
|
||||
for name in model_names:
|
||||
safe_name = name.replace('/','--')
|
||||
directory = self.globals.controlnet_path / safe_name
|
||||
if directory.exists():
|
||||
self.logger.info(f'Purging controlnet model {name}')
|
||||
shutil.rmtree(str(directory))
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,66 +16,59 @@ from compel.prompt_parser import (
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def get_uc_and_c_and_ec(
|
||||
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
|
||||
):
|
||||
def get_uc_and_c_and_ec(prompt_string,
|
||||
model: InvokeAIDiffuserComponent,
|
||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_string
|
||||
)
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False
|
||||
)
|
||||
compel = Compel(tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
(
|
||||
positive_prompt_string,
|
||||
negative_prompt_string,
|
||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
if log_tokens or getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
||||
if log_tokens or config.log_tokenization:
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
tokens_count = get_max_token_count(tokenizer, positive_prompt)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get(
|
||||
'cross_attention_control', None))
|
||||
return uc, c, ec
|
||||
|
||||
|
||||
def get_prompt_structure(
|
||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
@@ -86,18 +79,17 @@ def get_prompt_structure(
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
positive_prompt: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
return positive_prompt, negative_prompt
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
@@ -162,8 +154,8 @@ def log_tokenization(
|
||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||
tokenizer,
|
||||
):
|
||||
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
|
||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||
log_tokenization_for_prompt_object(
|
||||
@@ -237,29 +229,28 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
logger.debug(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
print(f"{discarded}\x1b[0m")
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(
|
||||
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
|
||||
)
|
||||
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
"""
|
||||
@@ -291,12 +282,14 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
||||
for match in re.finditer(prompt_parser, text)
|
||||
]
|
||||
if len(parsed_prompts) == 0:
|
||||
return []
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
logger.warning(
|
||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
)
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Restoration:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
@@ -8,17 +10,17 @@ class Restoration:
|
||||
# Load GFPGAN
|
||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||
if gfpgan.gfpgan_model_exists:
|
||||
print(">> GFPGAN Initialized")
|
||||
logger.info("GFPGAN Initialized")
|
||||
else:
|
||||
print(">> GFPGAN Disabled")
|
||||
logger.info("GFPGAN Disabled")
|
||||
gfpgan = None
|
||||
|
||||
# Load CodeFormer
|
||||
codeformer = self.load_codeformer()
|
||||
if codeformer.codeformer_model_exists:
|
||||
print(">> CodeFormer Initialized")
|
||||
logger.info("CodeFormer Initialized")
|
||||
else:
|
||||
print(">> CodeFormer Disabled")
|
||||
logger.info("CodeFormer Disabled")
|
||||
codeformer = None
|
||||
|
||||
return gfpgan, codeformer
|
||||
@@ -39,5 +41,5 @@ class Restoration:
|
||||
from .realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN(esrgan_bg_tile)
|
||||
print(">> ESRGAN Initialized")
|
||||
logger.info("ESRGAN Initialized")
|
||||
return esrgan
|
||||
|
||||
@@ -5,7 +5,8 @@ import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..globals import Globals
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
pretrained_model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
@@ -16,19 +17,19 @@ class CodeFormerRestoration:
|
||||
def __init__(
|
||||
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
||||
) -> None:
|
||||
if not os.path.isabs(codeformer_dir):
|
||||
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
|
||||
|
||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
self.globals = InvokeAIAppConfig.get_config()
|
||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
||||
self.model_path = codeformer_dir / codeformer_model_path
|
||||
self.codeformer_model_exists = self.model_path.exists()
|
||||
|
||||
if not self.codeformer_model_exists:
|
||||
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
sys.path.append(os.path.abspath(codeformer_dir))
|
||||
|
||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||
if seed is not None:
|
||||
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
|
||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
@@ -70,9 +71,7 @@ class CodeFormerRestoration:
|
||||
upscale_factor=1,
|
||||
use_parse=True,
|
||||
device=device,
|
||||
model_rootpath=os.path.join(
|
||||
Globals.root, "models", "gfpgan", "weights"
|
||||
),
|
||||
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
|
||||
)
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(bgr_image_array)
|
||||
@@ -97,7 +96,7 @@ class CodeFormerRestoration:
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except RuntimeError as error:
|
||||
print(f"\tFailed inference for CodeFormer: {error}.")
|
||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype("uint8")
|
||||
|
||||
@@ -6,20 +6,19 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
class GFPGAN:
|
||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||
self.globals = InvokeAIAppConfig.get_config()
|
||||
if not os.path.isabs(gfpgan_model_path):
|
||||
gfpgan_model_path = os.path.abspath(
|
||||
os.path.join(Globals.root, gfpgan_model_path)
|
||||
)
|
||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
||||
self.model_path = gfpgan_model_path
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.gfpgan_model_exists:
|
||||
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
return None
|
||||
|
||||
def model_exists(self):
|
||||
@@ -27,13 +26,13 @@ class GFPGAN:
|
||||
|
||||
def process(self, image, strength: float, seed: str = None):
|
||||
if seed is not None:
|
||||
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
|
||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
cwd = os.getcwd()
|
||||
os.chdir(os.path.join(Globals.root, "models"))
|
||||
os.chdir(self.globals.root_dir / 'models')
|
||||
try:
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
@@ -47,14 +46,14 @@ class GFPGAN:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
print(">> Error loading GFPGAN:", file=sys.stderr)
|
||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
os.chdir(cwd)
|
||||
|
||||
if self.gfpgan is None:
|
||||
print(f">> WARNING: GFPGAN not initialized.")
|
||||
print(
|
||||
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||
logger.warning("WARNING: GFPGAN not initialized.")
|
||||
logger.warning(
|
||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||
)
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Outcrop(object):
|
||||
def __init__(
|
||||
@@ -82,7 +82,7 @@ class Outcrop(object):
|
||||
pixels = extents[direction]
|
||||
# round pixels up to the nearest 64
|
||||
pixels = math.ceil(pixels / 64) * 64
|
||||
print(f">> extending image {direction}ward by {pixels} pixels")
|
||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
||||
image = self._rotate(image, direction)
|
||||
image = self._extend(image, pixels)
|
||||
image = self._rotate(image, direction, reverse=True)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@@ -6,18 +5,14 @@ import torch
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
class ESRGAN:
|
||||
def __init__(self, bg_tile_size=400) -> None:
|
||||
self.bg_tile_size = bg_tile_size
|
||||
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
@@ -35,12 +30,8 @@ class ESRGAN:
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
)
|
||||
model_path = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
wdn_model_path = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
scale = 4
|
||||
|
||||
bg_upsampler = RealESRGANer(
|
||||
@@ -74,16 +65,16 @@ class ESRGAN:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
|
||||
logger.error("Error loading Real-ESRGAN:")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if upsampler_scale == 0:
|
||||
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||
return image
|
||||
|
||||
if seed is not None:
|
||||
print(
|
||||
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||
logger.info(
|
||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||
)
|
||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -14,9 +14,12 @@ from PIL import Image, ImageFilter
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
from .globals import global_cache_dir
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from .util import CPU_DEVICE
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
class SafetyChecker(object):
|
||||
CAUTION_IMG = "caution.png"
|
||||
|
||||
@@ -25,10 +28,10 @@ class SafetyChecker(object):
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||
self.device = device
|
||||
|
||||
|
||||
try:
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = global_cache_dir("hub")
|
||||
safety_model_path = config.cache_dir
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
@@ -40,8 +43,8 @@ class SafetyChecker(object):
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"** An error was encountered while installing the safety checker:"
|
||||
logger.error(
|
||||
"An error was encountered while installing the safety checker:"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
|
||||
@@ -65,8 +68,8 @@ class SafetyChecker(object):
|
||||
)
|
||||
self.safety_checker.to(CPU_DEVICE) # offload
|
||||
if has_nsfw_concept[0]:
|
||||
print(
|
||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||
logger.warning(
|
||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
||||
)
|
||||
return self.blur(image)
|
||||
else:
|
||||
|
||||
@@ -17,15 +17,17 @@ from huggingface_hub import (
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
"""
|
||||
Initialize the Concepts object. May optionally pass a root directory.
|
||||
"""
|
||||
self.root = root or Globals.root
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.root = root or self.config.root
|
||||
self.hf_api = HfApi()
|
||||
self.local_concepts = dict()
|
||||
self.concept_list = None
|
||||
@@ -57,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
elif Globals.internet_available is True:
|
||||
elif self.config.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
@@ -66,11 +68,11 @@ class HuggingFaceConceptsLibrary(object):
|
||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
except Exception as e:
|
||||
print(
|
||||
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
logger.warning(
|
||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
)
|
||||
print(
|
||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
logger.warning(
|
||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
@@ -83,7 +85,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
be downloaded.
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
print(
|
||||
logger.warning(
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
@@ -221,7 +223,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
if chunk == 0:
|
||||
bytes += total
|
||||
|
||||
print(f">> Downloading {repo_id}...", end="")
|
||||
logger.info(f"Downloading {repo_id}...", end="")
|
||||
try:
|
||||
for file in (
|
||||
"README.md",
|
||||
@@ -235,22 +237,22 @@ class HuggingFaceConceptsLibrary(object):
|
||||
)
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
print(
|
||||
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
logger.error(
|
||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
print("...{:.2f}Kb".format(bytes / 1024))
|
||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||
return succeeded
|
||||
|
||||
def _concept_id(self, concept_name: str) -> str:
|
||||
|
||||
@@ -2,10 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import math
|
||||
import secrets
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import einops
|
||||
import PIL.Image
|
||||
@@ -38,8 +40,7 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..util import CPU_DEVICE, normalize_device
|
||||
from .diffusion import (
|
||||
AttentionMapSaver,
|
||||
@@ -49,7 +50,6 @@ from .diffusion import (
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
run_id: str
|
||||
@@ -75,10 +75,10 @@ class AddsMaskLatents:
|
||||
initial_image_latents: torch.Tensor
|
||||
|
||||
def __call__(
|
||||
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor
|
||||
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
|
||||
) -> torch.Tensor:
|
||||
model_input = self.add_mask_channels(latents)
|
||||
return self.forward(model_input, t, text_embeddings)
|
||||
return self.forward(model_input, t, text_embeddings, **kwargs)
|
||||
|
||||
def add_mask_channels(self, latents):
|
||||
batch_size = latents.size(0)
|
||||
@@ -214,12 +214,19 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
raise AssertionError("why was that an empty generator?")
|
||||
return result
|
||||
|
||||
@dataclass
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
image_tensor: torch.Tensor= Field(default=None)
|
||||
weight: Union[float, List[float]]= Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: torch.Tensor
|
||||
text_embeddings: torch.Tensor
|
||||
guidance_scale: float
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
@@ -357,10 +364,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and is_xformers_available()
|
||||
and not Globals.disable_xformers
|
||||
and not config.disable_xformers
|
||||
):
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
@@ -519,12 +527,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance: List[Callable] = None,
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
if timesteps is None:
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps, device=self._model_group.device_for(self.unet)
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||
@@ -537,6 +549,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance=additional_guidance,
|
||||
run_id=run_id,
|
||||
callback=callback,
|
||||
control_data=control_data,
|
||||
**kwargs,
|
||||
)
|
||||
return result.latents, result.attention_map_saver
|
||||
@@ -550,6 +563,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: torch.Tensor,
|
||||
run_id: str = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
@@ -559,8 +573,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
@@ -579,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t.fill_(t)
|
||||
step_output = self.step(
|
||||
@@ -589,6 +604,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
**kwargs,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
@@ -630,11 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
@@ -642,32 +658,55 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
if (self.control_model is not None) and (kwargs.get("control_image") is not None):
|
||||
control_image = kwargs.get("control_image") # should be a processed tensor derived from the control image(s)
|
||||
control_scale = kwargs.get("control_scale", 1.0) # control_scale default is 1.0
|
||||
# handling case where using multiple control models but only specifying single control_scale
|
||||
# so reshape control_scale to match number of control models
|
||||
if isinstance(self.control_model, MultiControlNetModel) and isinstance(control_scale, float):
|
||||
control_scale = [control_scale] * len(self.control_model.nets)
|
||||
if conditioning_data.guidance_scale > 1.0:
|
||||
# default is no controlnet, so set controlnet processing output to None
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
if control_data is not None:
|
||||
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
|
||||
# if conditioning_data.guidance_scale > 1.0:
|
||||
if conditioning_data.guidance_scale is not None:
|
||||
# expand the latents input to control model if doing classifier free guidance
|
||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||
# classifier_free_guidance is <= 1.0 ?)
|
||||
latent_control_input = torch.cat([latent_model_input] * 2)
|
||||
else:
|
||||
latent_control_input = latent_model_input
|
||||
# controlnet inference
|
||||
down_block_res_samples, mid_block_res_sample = self.control_model(
|
||||
latent_control_input,
|
||||
timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_image,
|
||||
conditioning_scale=control_scale,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
# control_data should be type List[ControlNetData]
|
||||
# this loop covers both ControlNet (one ControlNetData in list)
|
||||
# and MultiControlNet (multiple ControlNetData in list)
|
||||
for i, control_datum in enumerate(control_data):
|
||||
# print("controlnet", i, "==>", type(control_datum))
|
||||
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
||||
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||
if step_index >= first_control_step and step_index <= last_control_step:
|
||||
# print("running controlnet", i, "for step", step_index)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
controlnet_weight = control_datum.weight[step_index]
|
||||
else:
|
||||
# if controlnet has a single weight, use it for all steps
|
||||
controlnet_weight = control_datum.weight
|
||||
down_samples, mid_sample = control_datum.model(
|
||||
sample=latent_control_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=controlnet_weight,
|
||||
# cross_attention_kwargs,
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
)
|
||||
if down_block_res_samples is None and mid_block_res_sample is None:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
# add controlnet outputs together if have multiple controlnets
|
||||
down_block_res_samples = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||
]
|
||||
mid_block_res_sample += mid_sample
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||
@@ -678,8 +717,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data.guidance_scale,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
||||
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
@@ -773,11 +812,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
run_id=None,
|
||||
callback=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps,
|
||||
strength,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||
@@ -803,13 +838,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def get_img2img_timesteps(
|
||||
self, num_inference_steps: int, strength: float, device
|
||||
self, num_inference_steps: int, strength: float, device=None
|
||||
) -> (torch.Tensor, int):
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
num_inference_steps, strength, device=scheduler_device
|
||||
)
|
||||
# Workaround for low strength resulting in zero timesteps.
|
||||
# TODO: submit upstream fix for zero-step img2img
|
||||
@@ -843,9 +884,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
)
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
@@ -990,14 +1029,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
width=512,
|
||||
height=512,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents,
|
||||
width=512, # should be 8 * latent.shape[3]
|
||||
height=512, # should be 8 * latent height[2]
|
||||
batch_size=1,
|
||||
num_images_per_prompt=1,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
):
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
import enum
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
@@ -9,12 +10,13 @@ import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...util import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
@@ -351,8 +353,7 @@ def restore_default_cross_attention(
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def override_cross_attention(model, context: Context, is_running_diffusers=False):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@@ -371,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next(
|
||||
(
|
||||
p.slice_size
|
||||
for p in old_attn_processors.values()
|
||||
if type(p) is SlicedAttnProcessor
|
||||
),
|
||||
default_slice_size,
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
return old_attn_processors
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
context.register_cross_attention_modules(model)
|
||||
inject_attention_function(model, context)
|
||||
return None
|
||||
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
def get_cross_attention_modules(
|
||||
model, which: CrossAttentionType
|
||||
@@ -420,7 +406,7 @@ def get_cross_attention_modules(
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
print(
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
from .cross_attention_control import (
|
||||
Arguments,
|
||||
@@ -16,8 +18,8 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
override_cross_attention,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
@@ -30,7 +32,6 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
@@ -71,51 +72,65 @@ class InvokeAIDiffuserComponent:
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = Globals.sequential_guidance
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
|
||||
cls,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int
|
||||
):
|
||||
do_swap = (
|
||||
extra_conditioning_info is not None
|
||||
and extra_conditioning_info.wants_cross_attention_control
|
||||
)
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(
|
||||
extra_conditioning_info, step_count=step_count
|
||||
)
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (
|
||||
extra_conditioning_info.wants_cross_attention_control
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
cross_attention_control_context,
|
||||
)
|
||||
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if old_attn_processor is not None:
|
||||
self.restore_default_cross_attention(old_attn_processor)
|
||||
if old_attn_processors is not None:
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
def override_cross_attention(
|
||||
self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
) -> Dict[str, AttentionProcessor]:
|
||||
"""
|
||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
the previous attention processor is returned so that the caller can restore it later.
|
||||
"""
|
||||
self.conditioning = conditioning
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
return override_cross_attention(
|
||||
self.model,
|
||||
self.cross_attention_control_context,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
)
|
||||
# apparently unused code
|
||||
# TODO: delete
|
||||
# def override_cross_attention(
|
||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
# ) -> Dict[str, AttentionProcessor]:
|
||||
# """
|
||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
# the previous attention processor is returned so that the caller can restore it later.
|
||||
# """
|
||||
# self.conditioning = conditioning
|
||||
# self.cross_attention_control_context = Context(
|
||||
# arguments=self.conditioning.cross_attention_control_args,
|
||||
# step_count=step_count,
|
||||
# )
|
||||
# return override_cross_attention(
|
||||
# self.model,
|
||||
# self.cross_attention_control_context,
|
||||
# is_running_diffusers=self.is_running_diffusers,
|
||||
# )
|
||||
|
||||
def restore_default_cross_attention(
|
||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||
@@ -165,7 +180,8 @@ class InvokeAIDiffuserComponent:
|
||||
sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor, dict],
|
||||
conditioning: Union[torch.Tensor, dict],
|
||||
unconditional_guidance_scale: float,
|
||||
# unconditional_guidance_scale: float,
|
||||
unconditional_guidance_scale: Union[float, List[float]],
|
||||
step_index: Optional[int] = None,
|
||||
total_step_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
@@ -180,6 +196,11 @@ class InvokeAIDiffuserComponent:
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
if isinstance(unconditional_guidance_scale, list):
|
||||
guidance_scale = unconditional_guidance_scale[step_index]
|
||||
else:
|
||||
guidance_scale = unconditional_guidance_scale
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
@@ -228,7 +249,8 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
combined_next_x = self._combine(
|
||||
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
unconditioned_next_x, conditioned_next_x, guidance_scale
|
||||
)
|
||||
|
||||
return combined_next_x
|
||||
@@ -476,10 +498,14 @@ class InvokeAIDiffuserComponent:
|
||||
outside = torch.count_nonzero(
|
||||
(latents < -current_threshold) | (latents > current_threshold)
|
||||
)
|
||||
print(
|
||||
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
logger.info(
|
||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
||||
)
|
||||
logger.debug(
|
||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||
)
|
||||
logger.debug(
|
||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
)
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
@@ -506,9 +532,11 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
if self.debug_thresholding:
|
||||
print(
|
||||
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
logger.debug(
|
||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
||||
)
|
||||
logger.debug(
|
||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
||||
|
||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def mkdirs(paths):
|
||||
def mkdir_and_rename(path):
|
||||
if os.path.exists(path):
|
||||
new_name = path + "_archived_" + get_timestamp()
|
||||
print("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
os.replace(path, new_name)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user