Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-18 08:18:08 -05:00)

Compare commits: nodes...feat/nodes (962 commits)
Deleted file (coverage configuration):

@@ -1,6 +0,0 @@
-[run]
-omit='.env/*'
-source='.'
-
-[report]
-show_missing = true
Modified file (Docker ignore rules):

@@ -4,22 +4,22 @@
!ldm
!pyproject.toml

# Guard against pulling in any models that might exist in the directory tree
**/*.pt*
**/*.ckpt

# ignore frontend but whitelist dist
invokeai/frontend/
!invokeai/frontend/dist/
# ignore frontend/web but whitelist dist
invokeai/frontend/web/
!invokeai/frontend/web/dist/

# ignore invokeai/assets but whitelist invokeai/assets/web
invokeai/assets/
!invokeai/assets/web/

# Guard against pulling in any models that might exist in the directory tree
**/*.pt*
**/*.ckpt

# Byte-compiled / optimized / DLL files
**/__pycache__/
**/*.py[cod]

# Distribution / packaging
*.egg-info/
*.egg
**/*.egg-info/
**/*.egg
.git-blame-ignore-revs (new file, 1 line)

@@ -0,0 +1 @@
+b3dccfaeb636599c02effc377cdd8a87d658256c
.github/CODEOWNERS (59 changes, vendored)

@@ -1,51 +1,34 @@
# continuous integration
/.github/workflows/ @mauwii @lstein @blessedcoolant
/.github/workflows/ @lstein @blessedcoolant

# documentation
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
mkdocs.yml @lstein @mauwii @blessedcoolant
/docs/ @lstein @tildebyte @blessedcoolant
/mkdocs.yml @lstein @blessedcoolant

# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant

# installation and configuration
/pyproject.toml @mauwii @lstein @ebr @blessedcoolant
/docker/ @mauwii @lstein @blessedcoolant
/scripts/ @ebr @lstein @blessedcoolant
/installer/ @ebr @lstein @tildebyte @blessedcoolant
ldm/invoke/config @lstein @ebr @blessedcoolant
invokeai/assets @lstein @ebr @blessedcoolant
invokeai/configs @lstein @ebr @blessedcoolant
/ldm/invoke/_version.py @lstein @blessedcoolant
/pyproject.toml @lstein @blessedcoolant
/docker/ @lstein @blessedcoolant
/scripts/ @ebr @lstein
/installer/ @lstein @ebr
/invokeai/assets @lstein @ebr
/invokeai/configs @lstein
/invokeai/version @lstein @blessedcoolant

# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
/invokeai/backend @blessedcoolant @psychedelicious @lstein

# generation and model management
/ldm/*.py @lstein @blessedcoolant
/ldm/generate.py @lstein @keturn @blessedcoolant
/ldm/invoke/args.py @lstein @blessedcoolant
/ldm/invoke/ckpt* @lstein @blessedcoolant
/ldm/invoke/ckpt_generator @lstein @blessedcoolant
/ldm/invoke/CLI.py @lstein @blessedcoolant
/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
/ldm/invoke/generator @keturn @damian0815 @blessedcoolant
/ldm/invoke/globals.py @lstein @blessedcoolant
/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
/ldm/invoke/model_manager.py @lstein @blessedcoolant
/ldm/invoke/txt2mask.py @lstein @blessedcoolant
/ldm/invoke/patchmatch.py @Kyle0654 @blessedcoolant @lstein
/ldm/invoke/restoration @lstein @blessedcoolant
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2

# attention, textual inversion, model configuration
/ldm/models @damian0815 @keturn @lstein @blessedcoolant
/ldm/modules @damian0815 @keturn @lstein @blessedcoolant
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant

# Nodes
apps/ @Kyle0654 @lstein @blessedcoolant

# legacy REST API
# is CapableWeb still engaged?
/ldm/invoke/pngwriter.py @CapableWeb @lstein @blessedcoolant
/ldm/invoke/server_legacy.py @CapableWeb @lstein @blessedcoolant
/scripts/legacy_api.py @CapableWeb @lstein @blessedcoolant
/tests/legacy_tests.sh @CapableWeb @lstein @blessedcoolant
.github/ISSUE_TEMPLATE/BUG_REPORT.yml (10 changes, vendored)

@@ -65,6 +65,16 @@ body:
      placeholder: 8GB
    validations:
      required: false

  - type: input
    id: version-number
    attributes:
      label: What version did you experience this issue on?
      description: |
        Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
      placeholder: X.X.X
    validations:
      required: true

  - type: textarea
    id: what-happened
.github/stale.yaml (new file, 19 lines, vendored)

@@ -0,0 +1,19 @@
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 28
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 14
+# Issues with these labels will never be considered stale
+exemptLabels:
+  - pinned
+  - security
+# Label to use when marking an issue as stale
+staleLabel: stale
+# Comment to post when marking an issue as stale. Set to `false` to disable
+markComment: >
+  This issue has been automatically marked as stale because it has not had
+  recent activity. It will be closed if no further activity occurs. Please
+  update the ticket if this is still a problem on the latest release.
+# Comment to post when closing a stale issue. Set to `false` to disable
+closeComment: >
+  Due to inactivity, this issue has been automatically closed. If this is
+  still a problem on the latest release, please recreate the issue.
.github/workflows/build-container.yml (23 changes, vendored)

@@ -5,17 +5,20 @@ on:
      - 'main'
      - 'update/ci/docker/*'
      - 'update/docker/*'
      - 'dev/ci/docker/*'
      - 'dev/docker/*'
    paths:
      - 'pyproject.toml'
      - 'ldm/**'
      - 'invokeai/backend/**'
      - 'invokeai/configs/**'
      - 'invokeai/frontend/dist/**'
      - '.dockerignore'
      - 'invokeai/**'
      - 'docker/Dockerfile'
    tags:
      - 'v*.*.*'
  workflow_dispatch:

permissions:
  contents: write
  packages: write

jobs:
  docker:
@@ -24,11 +27,11 @@ jobs:
      fail-fast: false
      matrix:
        flavor:
          - amd
          - rocm
          - cuda
          - cpu
        include:
          - flavor: amd
          - flavor: rocm
            pip-extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
          - flavor: cuda
            pip-extra-index-url: ''
@@ -54,9 +57,9 @@ jobs:
          tags: |
            type=ref,event=branch
            type=ref,event=tag
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=semver,pattern={{major}}
            type=pep440,pattern={{version}}
            type=pep440,pattern={{major}}.{{minor}}
            type=pep440,pattern={{major}}
            type=sha,enable=true,prefix=sha-,format=short
          flavor: |
            latest=${{ matrix.flavor == 'cuda' && github.ref == 'refs/heads/main' }}
@@ -92,7 +95,7 @@ jobs:
          context: .
          file: ${{ env.DOCKERFILE }}
          platforms: ${{ env.PLATFORMS }}
          push: ${{ github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' }}
          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: PIP_EXTRA_INDEX_URL=${{ matrix.pip-extra-index-url }}
.github/workflows/close-inactive-issues.yml (new file, 27 lines, vendored)

@@ -0,0 +1,27 @@
+name: Close inactive issues
+on:
+  schedule:
+    - cron: "00 6 * * *"
+
+env:
+  DAYS_BEFORE_ISSUE_STALE: 14
+  DAYS_BEFORE_ISSUE_CLOSE: 28
+
+jobs:
+  close-issues:
+    runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+    steps:
+      - uses: actions/stale@v5
+        with:
+          days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
+          days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
+          stale-issue-label: "Inactive Issue"
+          stale-issue-message: "There has been no activity in this issue for ${{ env.DAYS_BEFORE_ISSUE_STALE }} days. If this issue is still being experienced, please reply with an updated confirmation that the issue is still being experienced with the latest release."
+          close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
+          days-before-pr-stale: -1
+          days-before-pr-close: -1
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+          operations-per-run: 500
.github/workflows/lint-frontend.yml (22 changes, vendored)

@@ -3,14 +3,22 @@ name: Lint frontend
on:
  pull_request:
    paths:
      - 'invokeai/frontend/**'
      - 'invokeai/frontend/web/**'
    types:
      - 'ready_for_review'
      - 'opened'
      - 'synchronize'
  push:
    branches:
      - 'main'
    paths:
      - 'invokeai/frontend/**'
      - 'invokeai/frontend/web/**'
  merge_group:
  workflow_dispatch:

defaults:
  run:
    working-directory: invokeai/frontend
    working-directory: invokeai/frontend/web

jobs:
  lint-frontend:
@@ -23,7 +31,7 @@ jobs:
          node-version: '18'
      - uses: actions/checkout@v3
      - run: 'yarn install --frozen-lockfile'
      - run: 'yarn tsc'
      - run: 'yarn run madge'
      - run: 'yarn run lint --max-warnings=0'
      - run: 'yarn run prettier --check'
      - run: 'yarn run lint:tsc'
      - run: 'yarn run lint:madge'
      - run: 'yarn run lint:eslint'
      - run: 'yarn run lint:prettier'
.github/workflows/mkdocs-material.yml (18 changes, vendored)

@@ -2,13 +2,19 @@ name: mkdocs-material
on:
  push:
    branches:
      - 'main'
      - 'development'
      - 'refs/heads/v2.3'

permissions:
  contents: write

jobs:
  mkdocs-material:
    if: github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    env:
      REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
      REPO_NAME: '${{ github.repository }}'
      SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
    steps:
      - name: checkout sources
        uses: actions/checkout@v3
@@ -19,11 +25,15 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
          cache: pip
          cache-dependency-path: pyproject.toml

      - name: install requirements
        env:
          PIP_USE_PEP517: 1
        run: |
          python -m \
            pip install -r docs/requirements-mkdocs.txt
          pip install ".[docs]"

      - name: confirm buildability
        run: |
@@ -33,7 +43,7 @@ jobs:
            --verbose

      - name: deploy to gh-pages
        if: ${{ github.ref == 'refs/heads/main' }}
        if: ${{ github.ref == 'refs/heads/v2.3' }}
        run: |
          python -m \
            mkdocs gh-deploy \
.github/workflows/pypi-release.yml (2 changes, vendored)

@@ -3,7 +3,7 @@ name: PyPI Release
on:
  push:
    paths:
      - 'ldm/invoke/_version.py'
      - 'invokeai/version/invokeai_version.py'
  workflow_dispatch:

jobs:
.github/workflows/test-invoke-pip-skip.yml (11 changes, vendored)

@@ -1,12 +1,11 @@
name: Test invoke.py pip
on:
  pull_request:
    paths-ignore:
      - 'pyproject.toml'
      - 'ldm/**'
      - 'invokeai/backend/**'
      - 'invokeai/configs/**'
      - 'invokeai/frontend/dist/**'
    paths:
      - '**'
      - '!pyproject.toml'
      - '!invokeai/**'
      - 'invokeai/frontend/web/**'
  merge_group:
  workflow_dispatch:
.github/workflows/test-invoke-pip.yml (14 changes, vendored)

@@ -5,17 +5,13 @@ on:
      - 'main'
    paths:
      - 'pyproject.toml'
      - 'ldm/**'
      - 'invokeai/backend/**'
      - 'invokeai/configs/**'
      - 'invokeai/frontend/dist/**'
      - 'invokeai/**'
      - '!invokeai/frontend/web/**'
  pull_request:
    paths:
      - 'pyproject.toml'
      - 'ldm/**'
      - 'invokeai/backend/**'
      - 'invokeai/configs/**'
      - 'invokeai/frontend/dist/**'
      - 'invokeai/**'
      - '!invokeai/frontend/web/**'
    types:
      - 'ready_for_review'
      - 'opened'
@@ -112,7 +108,7 @@ jobs:
      - name: set INVOKEAI_OUTDIR
        run: >
          python -c
          "import os;from ldm.invoke.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
          "import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
          >> ${{ matrix.github-env }}

      - name: run invokeai-configure
.gitignore (14 changes, vendored)

@@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
configs/models.user.yaml
config/models.user.yml
invokeai.init
.version
.last_model

# ignore the Anaconda/Miniconda installer used while building Docker image
anaconda.sh
@@ -63,6 +65,7 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coveragerc
.coverage
.coverage.*
.cache
@@ -73,6 +76,7 @@ cov.xml
*.py,cover
.hypothesis/
.pytest_cache/
.pytest.ini
cover/
junit/

@@ -198,7 +202,7 @@ checkpoints
.DS_Store

# Let the frontend manage its own gitignore
!invokeai/frontend/*
!invokeai/frontend/web/*

# Scratch folder
.scratch/
@@ -213,11 +217,6 @@ gfpgan/
# config file (will be created by installer)
configs/models.yaml

# weights (will be created by installer)
models/ldm/stable-diffusion-v1/*.ckpt
models/clipseg
models/gfpgan

# ignore initfile
.invokeai

@@ -232,6 +231,3 @@ installer/install.bat
installer/install.sh
installer/update.bat
installer/update.sh

# no longer stored in source directory
models
Deleted file (pytest configuration):

@@ -1,5 +0,0 @@
-[pytest]
-DJANGO_SETTINGS_MODULE = webtas.settings
-; python_files = tests.py test_*.py *_tests.py
-
-addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml
README.md (13 changes)

@@ -33,6 +33,8 @@

</div>

_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_

InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.

**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
@@ -84,7 +86,7 @@ installing lots of models.

6. Wait while the installer does its thing. After installing the software,
   the installer will launch a script that lets you configure InvokeAI and
   select a set of starting image generaiton models.
   select a set of starting image generation models.

7. Find the folder that InvokeAI was installed into (it is not the
   same as the unpacked zip file directory!) The default location of this
@@ -139,15 +141,20 @@ not supported.
_For Windows/Linux with an NVIDIA GPU:_

```terminal
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
```

_For Linux with an AMD GPU:_

```sh
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```

_For non-GPU systems:_
```terminal
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```

_For Macintoshes, either Intel or M1/M2:_

```sh
coverage/.gitignore (new file, 4 lines, vendored)

@@ -0,0 +1,4 @@
+# Ignore everything in this directory
+*
+# Except this file
+!.gitignore
Modified file (Dockerfile):

@@ -4,15 +4,15 @@ ARG PYTHON_VERSION=3.9
##################
## base image ##
##################
FROM python:${PYTHON_VERSION}-slim AS python-base
FROM --platform=${TARGETPLATFORM} python:${PYTHON_VERSION}-slim AS python-base

LABEL org.opencontainers.image.authors="mauwii@outlook.de"

# prepare for buildkit cache
# Prepare apt for buildkit cache
RUN rm -f /etc/apt/apt.conf.d/docker-clean \
    && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' >/etc/apt/apt.conf.d/keep-cache

# Install necessary packages
# Install dependencies
RUN \
    --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -23,7 +23,7 @@ RUN \
    libglib2.0-0=2.66.* \
    libopencv-dev=4.5.*

# set working directory and env
# Set working directory and env
ARG APPDIR=/usr/src
ARG APPNAME=InvokeAI
WORKDIR ${APPDIR}
@@ -32,7 +32,7 @@ ENV PATH ${APPDIR}/${APPNAME}/bin:$PATH
ENV PYTHONDONTWRITEBYTECODE 1
# Turns off buffering for easier container logging
ENV PYTHONUNBUFFERED 1
# don't fall back to legacy build system
# Don't fall back to legacy build system
ENV PIP_USE_PEP517=1

#######################
@@ -40,7 +40,7 @@ ENV PIP_USE_PEP517=1
#######################
FROM python-base AS pyproject-builder

# Install dependencies
# Install build dependencies
RUN \
    --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -51,26 +51,30 @@ RUN \
    gcc=4:10.2.* \
    python3-dev=3.9.*

# prepare pip for buildkit cache
# Prepare pip for buildkit cache
ARG PIP_CACHE_DIR=/var/cache/buildkit/pip
ENV PIP_CACHE_DIR ${PIP_CACHE_DIR}
RUN mkdir -p ${PIP_CACHE_DIR}

# create virtual environment
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
# Create virtual environment
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
    python3 -m venv "${APPNAME}" \
    --upgrade-deps

# copy sources
COPY --link . .

# install pyproject.toml
# Install requirements
COPY --link pyproject.toml .
COPY --link invokeai/version/invokeai_version.py invokeai/version/__init__.py invokeai/version/
ARG PIP_EXTRA_INDEX_URL
ENV PIP_EXTRA_INDEX_URL ${PIP_EXTRA_INDEX_URL}
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
    "${APPNAME}"/bin/pip install .

# Install pyproject.toml
COPY --link . .
RUN --mount=type=cache,target=${PIP_CACHE_DIR} \
    "${APPNAME}/bin/pip" install .

# build patchmatch
# Build patchmatch
RUN python3 -c "from patchmatch import patch_match"

#####################
@@ -86,14 +90,14 @@ RUN useradd \
    -U \
    "${UNAME}"

# create volume directory
# Create volume directory
ARG VOLUME_DIR=/data
RUN mkdir -p "${VOLUME_DIR}" \
    && chown -R "${UNAME}" "${VOLUME_DIR}"
    && chown -hR "${UNAME}:${UNAME}" "${VOLUME_DIR}"

# setup runtime environment
USER ${UNAME}
COPY --chown=${UNAME} --from=pyproject-builder ${APPDIR}/${APPNAME} ${APPNAME}
# Setup runtime environment
USER ${UNAME}:${UNAME}
COPY --chown=${UNAME}:${UNAME} --from=pyproject-builder ${APPDIR}/${APPNAME} ${APPNAME}
ENV INVOKEAI_ROOT ${VOLUME_DIR}
ENV TRANSFORMERS_CACHE ${VOLUME_DIR}/.cache
ENV INVOKE_MODEL_RECONFIGURE "--yes --default_only"
Modified file (Docker build script):

@@ -41,7 +41,7 @@ else
fi

# Build Container
DOCKER_BUILDKIT=1 docker build \
docker build \
  --platform="${PLATFORM:-linux/amd64}" \
  --tag="${CONTAINER_IMAGE:-invokeai}" \
  ${CONTAINER_FLAVOR:+--build-arg="CONTAINER_FLAVOR=${CONTAINER_FLAVOR}"} \
@@ -49,3 +49,6 @@ CONTAINER_FLAVOR="${CONTAINER_FLAVOR-cuda}"
CONTAINER_TAG="${CONTAINER_TAG-"${INVOKEAI_BRANCH##*/}-${CONTAINER_FLAVOR}"}"
CONTAINER_IMAGE="${CONTAINER_REGISTRY}/${CONTAINER_REPOSITORY}:${CONTAINER_TAG}"
CONTAINER_IMAGE="${CONTAINER_IMAGE,,}"

# enable docker buildkit
export DOCKER_BUILDKIT=1
Modified file (Docker run script):

@@ -21,10 +21,10 @@ docker run \
  --tty \
  --rm \
  --platform="${PLATFORM}" \
  --name="${REPOSITORY_NAME,,}" \
  --hostname="${REPOSITORY_NAME,,}" \
  --mount=source="${VOLUMENAME}",target=/data \
  --mount type=bind,source="$(pwd)"/outputs,target=/data/outputs \
  --name="${REPOSITORY_NAME}" \
  --hostname="${REPOSITORY_NAME}" \
  --mount type=volume,volume-driver=local,source="${VOLUMENAME}",target=/data \
  --mount type=bind,source="$(pwd)"/outputs/,target=/data/outputs/ \
  ${MODELSPATH:+--mount="type=bind,source=${MODELSPATH},target=/data/models"} \
  ${HUGGING_FACE_HUB_TOKEN:+--env="HUGGING_FACE_HUB_TOKEN=${HUGGING_FACE_HUB_TOKEN}"} \
  --publish=9090:9090 \
@@ -32,7 +32,7 @@ docker run \
  ${GPU_FLAGS:+--gpus="${GPU_FLAGS}"} \
  "${CONTAINER_IMAGE}" ${@:+$@}

# Remove Trash folder
echo -e "\nCleaning trash folder ..."
for f in outputs/.Trash*; do
  if [ -e "$f" ]; then
    rm -Rf "$f"
New binary files (not shown):

docs/assets/contributing/html-detail.png (470 KiB)
docs/assets/contributing/html-overview.png (457 KiB)
Invocations documentation:

@@ -1,10 +1,18 @@
# Invocations

Invocations represent a single operation, its inputs, and its outputs. These
operations and their outputs can be chained together to generate and modify
images.

## Creating a new invocation

To create a new invocation, either find the appropriate module file in
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
that folder. All invocations in that folder will be discovered and made
available to the CLI and API automatically. Invocations make use of
[typing](https://docs.python.org/3/library/typing.html) and
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
into the CLI and API.

An invocation looks like this:
@@ -41,34 +49,54 @@ class UpscaleInvocation(BaseInvocation):
Each portion is important to implement correctly.

### Class definition and type

```py
class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""
    type: Literal['upscale'] = 'upscale'
```

All invocations must derive from `BaseInvocation`. They should have a docstring
that declares what they do in a single, short line. They should also have a
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
is what the user will type on the CLI or use in the API to create this
invocation. The `command_name` must be unique. The `type` must be assigned to
the value of the literal in the type hint.

### Inputs

```py
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description="The upscale level")
```

Inputs consist of three parts: a name, a type hint, and a `Field` with default,
description, and validation information. For example:

| Part      | Value                                                          | Description                                                                                                               |
| --------- | -------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
| Name      | `strength`                                                     | This field is referred to as `strength`                                                                                   |
| Type Hint | `float`                                                        | This field must be of type `float`                                                                                        |
| Field     | `Field(default=0.75, gt=0, le=1, description="The strength")`  | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |

Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
field to be parsed with `None` as a value, which enables linking to previous
invocations. All fields should either provide a default value or allow `None` as
a value, so that they can be overwritten with a linked output from another
invocation.

The special type `ImageField` is also used here. All images are passed as
`ImageField`, which protects them from pydantic validation errors (since images
only ever come from links).

Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and matching.
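The `Field` constraints in the table above can be checked in isolation. Here is
a minimal, self-contained sketch, assuming pydantic v1 (the version these
examples use); the `Params` model is invented for illustration:

```python
from pydantic import BaseModel, Field, ValidationError

class Params(BaseModel):
    # Same constraints as the `strength` input above: default 0.75, range (0, 1]
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")

print(Params().strength)   # 0.75 -- the default is used when no value is given

try:
    Params(strength=1.5)   # outside (0, 1], so validation fails
except ValidationError as err:
    print(err.errors()[0]["msg"])
```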
### Invoke Function

```py
def invoke(self, context: InvocationContext) -> ImageOutput:
    image = context.services.images.get(self.image.image_type, self.image.image_name)
@@ -88,13 +116,22 @@ Finally, note that for all linking, the `type` of the linked fields must match.
        image = ImageField(image_type = image_type, image_name = image_name)
    )
```

The `invoke` function is the last portion of an invocation. It is provided an
`InvocationContext` which contains services to perform work as well as a
`session_id` for use as needed. It should return a class with output values that
derives from `BaseInvocationOutput`.

Before being called, the invocation will have all of its fields set from
defaults, inputs, and finally links (overriding in that order).

Assume that this invocation may be running simultaneously with other
invocations, may be running on another machine, or in other interesting
scenarios. If you need functionality, please provide it as a service in the
`InvocationServices` class, and make sure it can be overridden.
### Outputs

```py
class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""
@@ -102,4 +139,64 @@ class ImageOutput(BaseInvocationOutput):

    image: ImageField = Field(default=None, description="The output image")
```

Output classes look like an invocation class without the invoke method. Prefer
to use an existing output class if available, and prefer to name inputs the same
as outputs when possible, to promote automatic invocation linking.

## Schema Generation

Invocation, output and related classes are used to generate an OpenAPI schema.

### Required Properties

The schema generation treats all properties with default values as optional.
This makes sense internally, but when using these classes via the generated
schema, we end up with e.g. the `ImageOutput` class having its `image` property
marked as optional.

We know that this property will always be present, so the additional logic
needed to always check if the property exists adds a lot of extraneous cruft.

To fix this, we can leverage `pydantic`'s
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
to mark properties that we know will always be present as required.

Here's that `ImageOutput` class, without the needed schema customisation:

```python
class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""

    type: Literal["image"] = "image"
    image: ImageField = Field(default=None, description="The output image")
```

The generated OpenAPI schema, and all clients/types generated from it, will have
the `type` and `image` properties marked as optional, even though we know they
will always have a value by the time we can interact with them via the API.

Here's the same class, but with the schema customisation added:

```python
class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""

    type: Literal["image"] = "image"
    image: ImageField = Field(default=None, description="The output image")

    class Config:
        schema_extra = {
            'required': [
                'type',
                'image',
            ]
        }
```

The resultant schema (and any API client or types generated from it) will now
show `type` as the string literal `"image"` and `image` as an `ImageField`
object.

See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577>
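A quick way to confirm the effect of the customisation is to inspect the
generated JSON schema directly. A minimal sketch, assuming pydantic v1, with
stub definitions standing in for the real InvokeAI classes:

```python
from typing import Literal
from pydantic import BaseModel, Field

class ImageField(BaseModel):
    image_type: str = ""
    image_name: str = ""

class ImageOutput(BaseModel):
    type: Literal["image"] = "image"
    image: ImageField = Field(default=None, description="The output image")

    class Config:
        # Merged into the generated schema, marking both properties required
        schema_extra = {"required": ["type", "image"]}

print(ImageOutput.schema()["required"])  # ['type', 'image']
```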
83
docs/contributing/LOCAL_DEVELOPMENT.md
Normal file
83
docs/contributing/LOCAL_DEVELOPMENT.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Local Development
|
||||
|
||||
If you are looking to contribute you will need to have a local development
|
||||
environment. See the
|
||||
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
|
||||
full details.
|
||||
|
||||
Broadly this involves cloning the repository, installing the pre-reqs, and
|
||||
InvokeAI (in editable form). Assuming this is working, choose your area of
|
||||
focus.
|
||||
|
||||
## Documentation
|
||||
|
||||
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
|
||||
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
|
||||
written in markdown files under the `./docs` folder and then built into a static
|
||||
website for hosting with GitHub Pages at
|
||||
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
|
||||
|
||||
To contribute to the documentation you'll need to install the dependencies. Note
|
||||
the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[docs]"
|
||||
```
|
||||
|
||||
Now, to run the documentation locally with hot-reloading for changes made.
|
||||
|
||||
```zsh
|
||||
mkdocs serve
|
||||
```
|
||||
|
||||
You'll then be prompted to connect to `http://127.0.0.1:8080` in order to
|
||||
access.
|
||||
|
||||
## Backend
|
||||
|
||||
The backend is contained within the `./invokeai/backend` folder structure. To
|
||||
get started however please install the development dependencies.
|
||||
|
||||
From the root of the repository run the following command. Note the use of `"`.
|
||||
|
||||
```zsh
|
||||
pip install ".[test]"
|
||||
```
|
||||
|
||||
This in an optional group of packages which is defined within the
|
||||
`pyproject.toml` and will be required for testing the changes you make the the
|
||||
code.
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
|
||||
be found under the `./tests` folder and can be run with a single `pytest`
|
||||
command. Optionally, to review test coverage you can append `--cov`.
|
||||
|
||||
```zsh
|
||||
pytest --cov
|
||||
```
|
||||
|
||||
Test outcomes and coverage will be reported in the terminal. In addition a more
|
||||
detailed report is created in both XML and HTML format in the `./coverage`
|
||||
folder. The HTML one in particular can help identify missing statements
|
||||
requiring tests to ensure coverage. View it by opening
|
||||
`./coverage/html/index.html`.
|
||||
|
||||
For example:
|
||||
|
||||
```zsh
|
||||
pytest --cov; open ./coverage/html/index.html
|
||||
```
|
||||
|
||||
??? info "HTML coverage report output"
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## Front End
|
||||
|
||||
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
|
||||
|
||||
--8<-- "invokeai/frontend/web/README.md"
|
||||
@@ -168,11 +168,15 @@ used by Stable Diffusion 1.4 and 1.5.
|
||||
After installation, your `models.yaml` should contain an entry that looks like
|
||||
this one:
|
||||
|
||||
inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||
description: SD inpainting v1.5 config:
|
||||
configs/stable-diffusion/v1-inpainting-inference.yaml vae:
|
||||
models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
|
||||
height: 512
|
||||
```yml
|
||||
inpainting-1.5:
|
||||
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||
description: SD inpainting v1.5
|
||||
config: configs/stable-diffusion/v1-inpainting-inference.yaml
|
||||
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
width: 512
|
||||
height: 512
|
||||
```
|
||||
|
||||
As shown in the example, you may include a VAE fine-tuning weights file as well.
|
||||
This is strongly recommended.
|
||||
|
||||
@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
|
||||
At installation time, InvokeAI will ask whether the checker should be
|
||||
activated by default (neither argument given on the command line). The
|
||||
response is stored in the InvokeAI initialization file (usually
|
||||
`.invokeai` in your home directory). You can change the default at any
|
||||
`invokeai.init` in your home directory). You can change the default at any
|
||||
time by opening this file in a text editor and commenting or
|
||||
uncommenting the line `--nsfw_checker`.
|
||||
|
||||
|
||||
@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
|
||||
masking option:
|
||||
|
||||
```bash
|
||||
invoke> a fluffy cat eating a hotdot
|
||||
invoke> a fluffy cat eating a hotdog
|
||||
Outputs:
|
||||
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
|
||||
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
|
||||
|
||||
@@ -17,7 +17,7 @@ notebooks.
|
||||
|
||||
You will need a GPU to perform training in a reasonable length of
|
||||
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
|
||||
library](../installation/070_INSTALL_XFORMERS) to accelerate the
|
||||
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
|
||||
training process further. During training, about ~8 GB is temporarily
|
||||
needed in order to store intermediate models, checkpoints and logs.
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ experimental versions later.
|
||||
sudo apt update
|
||||
sudo apt install -y software-properties-common
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt install python3.10 python3-pip python3.10-venv
|
||||
sudo apt install -y python3.10 python3-pip python3.10-venv
|
||||
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
|
||||
```
|
||||
|
||||
@@ -417,7 +417,7 @@ Then type the following commands:
|
||||
|
||||
=== "AMD System"
|
||||
```bash
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
```
|
||||
|
||||
### Corrupted configuration file
|
||||
|
||||
@@ -148,13 +148,13 @@ manager, please follow these steps:
|
||||
=== "CUDA (NVidia)"
|
||||
|
||||
```bash
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
|
||||
```bash
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
```
|
||||
|
||||
=== "CPU (Intel Macs & non-GPU systems)"
|
||||
@@ -315,7 +315,7 @@ installation protocol (important!)
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
```bash
|
||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
```
|
||||
|
||||
=== "CPU (Intel Macs & non-GPU systems)"
|
||||
|
||||
@@ -110,7 +110,7 @@ recipes are available
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
the argument `--extra-index-url
|
||||
https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
|
||||
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
|
||||
Installation Guide](020_INSTALL_MANUAL.md).
|
||||
|
||||
This will be done automatically for you if you use the installer
|
||||
|
||||
@@ -50,7 +50,7 @@ subset that are currently installed are found in
|
||||
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|
||||
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|
||||
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|
||||
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|
||||
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|
||||
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |
|
||||
|
||||
@@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
|
||||
brew install opencv
|
||||
```
|
||||
|
||||
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
|
||||
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
|
||||
|
||||
## Linux
|
||||
|
||||
@@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
|
||||
`python`, and then at the `>>>` line type
|
||||
`from patchmatch import patch_match`: It should look like the follwing:
|
||||
`from patchmatch import patch_match`: It should look like the following:
|
||||
|
||||
```py
|
||||
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
|
||||
@@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
|
||||
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
|
||||
|
||||
If you see no errors, then you're ready to go!
|
||||
If you see no errors you're ready to go!
|
||||
|
||||
@@ -11,10 +11,10 @@ if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
exit -1
|
||||
fi
|
||||
|
||||
VERSION=$(cd ..; python -c "from ldm.invoke import __version__ as version; print(version)")
|
||||
VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
|
||||
PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v2.3-latest"
|
||||
LATEST_TAG="v3.0-latest"
|
||||
|
||||
echo Building installer for version $VERSION
|
||||
echo "Be certain that you're in the 'installer' directory before continuing."
|
||||
|
||||
@@ -247,8 +247,8 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torch~=2.0.0",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
@@ -291,7 +291,7 @@ class InvokeAiInstance:
|
||||
src = Path(__file__).parents[1].expanduser().resolve()
|
||||
# if the above directory contains one of these files, we'll do a source install
|
||||
next(src.glob("pyproject.toml"))
|
||||
next(src.glob("ldm"))
|
||||
next(src.glob("invokeai"))
|
||||
except StopIteration:
|
||||
print("Unable to find a wheel or perform a source install. Giving up.")
|
||||
|
||||
@@ -342,14 +342,14 @@ class InvokeAiInstance:
|
||||
|
||||
introduction()
|
||||
|
||||
from ldm.invoke.config import invokeai_configure
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
|
||||
# NOTE: currently the config script does its own arg parsing! this means the command-line switches
|
||||
# from the installer will also automatically propagate down to the config script.
|
||||
# this may change in the future with config refactoring!
|
||||
succeeded = False
|
||||
try:
|
||||
invokeai_configure.main()
|
||||
invokeai_configure()
|
||||
succeeded = True
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print(f'\nA network error was encountered during configuration and download: {str(e)}')
|
||||
@@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
|
||||
optional_modules = None
|
||||
if OS == "Linux":
|
||||
if device == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.2"
|
||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||
elif device == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
while true
|
||||
do
|
||||
if [ "$0" != "bash" ]; then
|
||||
while true
|
||||
do
|
||||
echo "Do you want to generate images using the"
|
||||
echo "1. command-line interface"
|
||||
echo "2. browser-based UI"
|
||||
@@ -67,29 +67,29 @@ if [ "$0" != "bash" ]; then
|
||||
;;
|
||||
7)
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
;;
|
||||
8)
|
||||
echo "Developer Console:"
|
||||
;;
|
||||
8)
|
||||
echo "Developer Console:"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
9)
|
||||
echo "Update:"
|
||||
echo "Update:"
|
||||
invokeai-update
|
||||
;;
|
||||
10)
|
||||
invokeai --help
|
||||
;;
|
||||
[qQ])
|
||||
[qQ])
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid selection"
|
||||
exit;;
|
||||
esac
|
||||
done
|
||||
else # in developer console
|
||||
python --version
|
||||
echo "Press ^D to exit"
|
||||
export PS1="(InvokeAI) \u@\h \w> "
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
After version 2.3 is released, the ldm/invoke modules will be migrated to this location
|
||||
so that we have a proper invokeai distribution. Currently it is only being used for
|
||||
data files.
|
||||
Organization of the source tree:
|
||||
|
||||
app -- Home of nodes invocations and services
|
||||
assets -- Images and other data files used by InvokeAI
|
||||
backend -- Non-user facing libraries, including the rendering
|
||||
core.
|
||||
configs -- Configuration files used at install and run times
|
||||
frontend -- User-facing scripts, including the CLI and the WebUI
|
||||
version -- Current InvokeAI version string, stored
|
||||
in version/invokeai_version.py
|
||||
|
||||
118
invokeai/app/api/dependencies.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.metadata import PngMetadataService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
def check_internet() -> bool:
|
||||
"""
|
||||
Return true if the internet is reachable.
|
||||
It does this by pinging huggingface.co.
|
||||
"""
|
||||
import urllib.request
|
||||
|
||||
host = "http://huggingface.co"
|
||||
try:
|
||||
urllib.request.urlopen(host, timeout=1)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TO DO: Use the config to select the logger rather than use the default
|
||||
# invokeai logging module
|
||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||
)
|
||||
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
urls = LocalUrlService()
|
||||
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
|
||||
images_new = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config, logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=image_file_storage,
|
||||
images_new=images_new,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config, logger),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
if ApiDependencies.invoker:
|
||||
ApiDependencies.invoker.stop()
|
||||
@@ -1,11 +1,14 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
import threading
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
@@ -16,39 +19,34 @@ class FastAPIEventService(EventServiceBase):
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event = self.__stop_event))
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put(dict(
|
||||
event_name = event_name,
|
||||
payload = payload
|
||||
))
|
||||
|
||||
self.__queue.put(dict(event_name=event_name, payload=payload))
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block = False)
|
||||
if not event: # Probably stopping
|
||||
event = self.__queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get('event_name'),
|
||||
payload = event.get('payload'),
|
||||
middleware_id = self.event_handler_id)
|
||||
event.get("event_name"),
|
||||
payload=event.get("payload"),
|
||||
middleware_id=self.event_handler_id,
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.001)
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
raise e # Raise a proper error
|
||||
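The pattern above is worth noting: synchronous producers `put` events onto a
thread-safe `Queue`, and a long-lived asyncio task drains it on the event loop,
where it is safe to call loop-bound handlers. A minimal, self-contained sketch
of the same idea (an illustration, not InvokeAI code; the `print` stands in for
`dispatch`):

```python
import asyncio
import threading
from queue import Empty, Queue


async def drain(queue: Queue, stop: threading.Event) -> None:
    """Drain the queue from the event loop, yielding while it is empty."""
    while not stop.is_set():
        try:
            event = queue.get(block=False)
            if event is None:  # sentinel enqueued at shutdown
                continue
            print("dispatching", event)  # stand-in for dispatch(...)
        except Empty:
            await asyncio.sleep(0.1)  # yield instead of busy-waiting


async def main() -> None:
    queue: Queue = Queue()
    stop = threading.Event()
    task = asyncio.create_task(drain(queue, stop))

    # Any thread can enqueue without touching the event loop
    queue.put({"event_name": "progress", "payload": {"step": 1}})
    await asyncio.sleep(0.3)

    stop.set()
    queue.put(None)  # mirror the service's stop(): wake the drain loop
    await task


asyncio.run(main())
```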
40
invokeai/app/api/models/images.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
"""An image's metadata. Used only in HTTP responses."""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class SavedImage(BaseModel):
|
||||
image_name: str = Field(description="The name of the saved image")
|
||||
thumbnail_name: str = Field(description="The name of the saved thumbnail")
|
||||
created: int = Field(description="The created timestamp of the saved image")
|
||||
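These are plain pydantic models, so they validate on construction and serialize
straight to JSON. A small sketch with made-up values (using `ProgressImage`,
since it has no dependencies on other app types):

```python
from pydantic import BaseModel, Field


class ProgressImage(BaseModel):
    """The progress image sent intermittently during processing"""

    width: int = Field(description="The effective width of the image in pixels")
    height: int = Field(description="The effective height of the image in pixels")
    dataURL: str = Field(description="The image data as a b64 data URL")


# Values are made up for illustration
progress = ProgressImage(width=64, height=64, dataURL="data:image/png;base64,...")
print(progress.json())

# A missing or mistyped field raises pydantic.ValidationError:
# ProgressImage(width="not a number", height=64, dataURL="...")
```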
47
invokeai/app/api/routers/image_files.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from fastapi import HTTPException, Path
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.models.image import ImageType
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
|
||||
|
||||
|
||||
@image_files_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of the image to get"),
|
||||
image_name: str = Path(description="The id of the image to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images_new.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
return FileResponse(path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@image_files_router.get(
|
||||
"/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
image_type: ImageType = Path(
|
||||
description="The type of the image whose thumbnail to get"
|
||||
),
|
||||
image_name: str = Path(description="The id of the image whose thumbnail to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images_new.get_path(
|
||||
image_type=image_type, image_name=image_name, thumbnail=True
|
||||
)
|
||||
|
||||
return FileResponse(path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
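Assuming a local server on the default port 9090 (configured in `api_app.py`
later in this diff) and the `/api` prefix the routers are mounted under, a
client can fetch a file and its thumbnail like so. This is a hedged sketch
using `requests`; the `results` image type and the image name are hypothetical:

```python
import requests

BASE = "http://127.0.0.1:9090/api/v1/files/images"

# Hypothetical image type and name; the server returns 404 if they don't exist
resp = requests.get(f"{BASE}/results/000001.1234567890.png")
resp.raise_for_status()
with open("image.png", "wb") as f:
    f.write(resp.content)

thumb = requests.get(f"{BASE}/results/000001.1234567890.png/thumbnail")
thumb.raise_for_status()
```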
71
invokeai/app/api/routers/image_records.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from fastapi import HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
image_records_router = APIRouter(
|
||||
prefix="/v1/images/records", tags=["images", "records"]
|
||||
)
|
||||
|
||||
|
||||
@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record")
|
||||
async def get_image_record(
|
||||
image_type: ImageType = Path(description="The type of the image record to get"),
|
||||
image_name: str = Path(description="The id of the image record to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image record by id"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images_new.get_dto(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@image_records_router.get(
|
||||
"/",
|
||||
operation_id="list_image_records",
|
||||
)
|
||||
async def list_image_records(
|
||||
image_type: ImageType = Query(description="The type of image records to get"),
|
||||
image_category: ImageCategory = Query(
|
||||
description="The kind of image records to get"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of image records to get"),
|
||||
per_page: int = Query(
|
||||
default=10, description="The number of image records per page"
|
||||
),
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a list of image records by type and category"""
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images_new.get_many(
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
|
||||
@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image_record(
|
||||
image_type: ImageType = Query(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image record"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images_new.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
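The listing endpoint pages its results. A hedged client sketch (same local
server assumption; the `image_type` and `image_category` values are
hypothetical, and the shape of the `PaginatedResults` payload should be
checked against `item_storage`):

```python
import requests

resp = requests.get(
    "http://127.0.0.1:9090/api/v1/images/records/",
    params={
        "image_type": "results",    # hypothetical ImageType value
        "image_category": "image",  # hypothetical ImageCategory value
        "page": 0,
        "per_page": 10,
    },
)
resp.raise_for_status()
page = resp.json()  # a PaginatedResults[ImageDTO] payload
print(page)
```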
77
invokeai/app/api/routers/images.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import io
|
||||
import uuid
|
||||
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||
from invokeai.app.services.models.image_record import ImageRecord
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
415: {"description": "Image upload failed"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def upload_image(
|
||||
file: UploadFile,
|
||||
image_type: ImageType,
|
||||
request: Request,
|
||||
response: Response,
|
||||
image_category: ImageCategory = ImageCategory.IMAGE,
|
||||
) -> ImageRecord:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(contents))
|
||||
except:
|
||||
# Error opening the image
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
image_record = ApiDependencies.invoker.services.images_new.create(
|
||||
image=img,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = image_record.image_url
|
||||
|
||||
return image_record
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500)
|
||||
|
||||
|
||||
|
||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image_record(
|
||||
image_type: ImageType = Query(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image record"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images_new.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
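Uploads go through a multipart form, with `image_type` passed alongside it. A
hedged `requests` sketch (local server assumed; the file and the `results`
image type are hypothetical):

```python
import requests

with open("cat.png", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:9090/api/v1/images/",
        params={"image_type": "results"},  # hypothetical ImageType value
        files={"file": ("cat.png", f, "image/png")},
    )

print(resp.status_code)              # 201 on success, 415 for non-images
print(resp.headers.get("Location"))  # URL of the newly stored image
```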
248
invokeai/app/api/routers/models.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
config: str = Field(description="The path to the model config")
|
||||
weights: str = Field(description="The path to the model weights")
|
||||
vae: str = Field(description="The path to the model VAE")
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models() -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
"description": "Model deleted successfully"
|
||||
},
|
||||
404: {
|
||||
"description": "Model not found"
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error(f"Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
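The `update_model` route accepts a discriminated union keyed on `format`. A
hedged sketch posting a `ckpt` entry (the model name and every path are made
up; only the field names come from `CkptModelInfo` above):

```python
import requests

payload = {
    "name": "my-inpainting-model",  # hypothetical model name
    "info": {
        "format": "ckpt",  # selects CkptModelInfo in the discriminated union
        "description": "An example checkpoint entry",
        "config": "configs/stable-diffusion/v1-inference.yaml",
        "weights": "models/ldm/stable-diffusion-v1/my-model.ckpt",
        "vae": "models/ldm/stable-diffusion-v1/vae.ckpt",
        "width": 512,
        "height": 512,
    },
}

resp = requests.post("http://127.0.0.1:9090/api/v1/models/", json=payload)
print(resp.json())  # a CreateModelResponse with status "success"
```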
42
invokeai/app/api/routers/results.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.results import ResultType, ResultWithSession
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
results_router = APIRouter(prefix="/v1/results", tags=["results"])
|
||||
|
||||
|
||||
@results_router.get("/{result_type}/{result_name}", operation_id="get_result")
|
||||
async def get_result(
|
||||
result_type: ResultType = Path(description="The type of result to get"),
|
||||
result_name: str = Path(description="The name of the result to get"),
|
||||
) -> ResultWithSession:
|
||||
"""Gets a result"""
|
||||
|
||||
result = ApiDependencies.invoker.services.results.get(
|
||||
result_id=result_name, result_type=result_type
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@results_router.get(
|
||||
"/",
|
||||
operation_id="list_results",
|
||||
responses={200: {"model": PaginatedResults[ResultWithSession]}},
|
||||
)
|
||||
async def list_results(
|
||||
result_type: ResultType = Query(description="The type of results to get"),
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
per_page: int = Query(default=10, description="The number of results per page"),
|
||||
) -> PaginatedResults[ResultWithSession]:
|
||||
"""Gets a list of results"""
|
||||
results = ApiDependencies.invoker.services.results.get_many(
|
||||
result_type=result_type, page=page, per_page=per_page
|
||||
)
|
||||
return results
|
||||
286
invokeai/app/api/routers/sessions.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ...invocations import *
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
)
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/",
|
||||
operation_id="create_session",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
)
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(
|
||||
default=None, description="The graph to initialize the session with"
|
||||
)
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
return session
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
per_page: int = Query(default=10, description="The number of results per page"),
|
||||
query: str = Query(default="", description="The query string to search for"),
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||
page, per_page
|
||||
)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(
|
||||
query, page, per_page
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/{session_id}",
|
||||
operation_id="get_session",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
) -> GraphExecutionState:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
return session
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/{session_id}/nodes",
|
||||
operation_id="add_node",
|
||||
responses={
|
||||
200: {"model": str},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The node to add"),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
"/{session_id}/nodes/{node_path}",
|
||||
operation_id="update_node",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The new node"),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"/{session_id}/nodes/{node_path}",
|
||||
operation_id="delete_node",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def delete_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node to delete"),
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.post(
|
||||
"/{session_id}/edges",
|
||||
operation_id="add_edge",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
edge: Edge = Body(description="The edge to add"),
|
||||
) -> GraphExecutionState:
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@session_router.delete(
|
||||
"/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
||||
operation_id="delete_edge",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def delete_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
||||
from_field: str = Path(description="The field of the node the edge is coming from"),
|
||||
to_node_id: str = Path(description="The id of the node the edge is going to"),
|
||||
to_field: str = Path(description="The field of the node the edge is going to"),
|
||||
) -> GraphExecutionState:
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
session
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="invoke_session",
|
||||
responses={
|
||||
200: {"model": None},
|
||||
202: {"description": "The invocation is queued"},
|
||||
400: {"description": "The session has no invocations ready to invoke"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def invoke_session(
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
if session.is_complete():
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={
|
||||
202: {"description": "The invocation is canceled"}
|
||||
},
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
ApiDependencies.invoker.cancel(session_id)
|
||||
return Response(status_code=202)
|
||||
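Putting the session routes together, the typical client flow is create, add
nodes and edges, then invoke. A hedged sketch (the `txt2img` node type and its
fields are hypothetical; the real set comes from
`BaseInvocation.get_invocations()`):

```python
import requests

BASE = "http://127.0.0.1:9090/api/v1/sessions"

# 1. Create an empty session
session = requests.post(f"{BASE}/").json()
session_id = session["id"]

# 2. Add a node (hypothetical invocation type and fields)
node = {"id": "1", "type": "txt2img", "prompt": "a fluffy cat"}
requests.post(f"{BASE}/{session_id}/nodes", json=node)

# 3. Invoke everything that is ready; 202 means the work was queued
resp = requests.put(f"{BASE}/{session_id}/invoke", params={"all": True})
print(resp.status_code)  # 202
```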
@@ -1,36 +1,38 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi_socketio import SocketManager
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from fastapi_socketio import SocketManager
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
__sio: SocketManager
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app = app)
|
||||
self.__sio.on('subscribe', handler=self._handle_sub)
|
||||
self.__sio.on('unsubscribe', handler=self._handle_unsub)
|
||||
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
local_handler.register(
|
||||
event_name = EventServiceBase.session_event,
|
||||
_func=self._handle_session_event
|
||||
event_name=EventServiceBase.session_event, _func=self._handle_session_event
|
||||
)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event = event[1]['event'],
|
||||
data = event[1]['data'],
|
||||
room = event[1]['data']['graph_execution_state_id']
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
if 'session' in data:
|
||||
self.__sio.enter_room(sid, data['session'])
|
||||
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
if 'session' in data:
|
||||
self.__sio.leave_room(sid, data['session'])
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
@@ -1,37 +1,39 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
from inspect import signature
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
import uvicorn
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .api.routers import images, sessions
|
||||
from invokeai.app.models import resources
|
||||
import invokeai.backend.util.logging as logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from ..args import Args
|
||||
from .api.routers import image_files, image_records, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(
|
||||
title = "Invoke AI",
|
||||
docs_url = None,
|
||||
redoc_url = None
|
||||
)
|
||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers = [local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id = event_handler_id)
|
||||
handlers=[
|
||||
local_handler
|
||||
], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
@@ -48,38 +50,38 @@ socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event('startup')
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
args = args,
|
||||
config = config,
|
||||
event_handler_id = event_handler_id
|
||||
config=config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event('shutdown')
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Include all routers
|
||||
# TODO: REMOVE
|
||||
# app.include_router(
|
||||
# invocation.invocation_router,
|
||||
# prefix = '/api')
|
||||
|
||||
app.include_router(
|
||||
sessions.session_router,
|
||||
prefix = '/api'
|
||||
)
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(image_files.image_files_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(image_records.image_records_router, prefix="/api")
|
||||
|
||||
app.include_router(
|
||||
images.images_router,
|
||||
prefix = '/api'
|
||||
)
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@@ -87,10 +89,10 @@ def custom_openapi():
|
||||
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title = app.title,
        description = "An API for invoking AI image operations",
        version = "1.0.0",
        routes = app.routes
        title=app.title,
        description="An API for invoking AI image operations",
        version="1.0.0",
        routes=app.routes,
    )

    # Add all outputs
@@ -102,12 +104,12 @@ def custom_openapi():
        output_types.add(output_type)

    output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
    for schema_key, output_schema in output_schemas['definitions'].items():
    for schema_key, output_schema in output_schemas["definitions"].items():
        openapi_schema["components"]["schemas"][schema_key] = output_schema

        # TODO: note that we assume the schema_key here is the TYPE.__name__
        # This could break in some cases, figure out a better way to do it
        output_type_titles[schema_key] = output_schema['title']
        output_type_titles[schema_key] = output_schema["title"]

    # Add a reference to the output type to additionalProperties of the invoker schema
    for invoker in all_invocations:
@@ -115,47 +117,51 @@ def custom_openapi():
        output_type = signature(invoker.invoke).return_annotation
        output_type_title = output_type_titles[output_type.__name__]
        invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
        outputs_ref = { '$ref': f'#/components/schemas/{output_type_title}' }
        if 'additionalProperties' not in invoker_schema:
            invoker_schema['additionalProperties'] = {}
        outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}

        invoker_schema["output"] = outputs_ref

        invoker_schema['additionalProperties']['outputs'] = outputs_ref

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

# Override API doc favicons
app.mount('/static', StaticFiles(directory='static/dream_web'), name='static')
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
    return get_swagger_ui_html(
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title,
        swagger_favicon_url="/static/favicon.ico"
        swagger_favicon_url="/static/favicon.ico",
    )


@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
    return get_redoc_html(
    return get_redoc_html(
        openapi_url=app.openapi_url,
        title=app.title,
        redoc_favicon_url="/static/favicon.ico"
        redoc_favicon_url="/static/favicon.ico",
    )


# Must mount *after* the other routes else it borks em
app.mount(
    "/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui"
)


def invoke_api():
    # Start our own event loop for eventing usage
    # TODO: determine if there's a better way to do this
    loop = asyncio.new_event_loop()
    config = uvicorn.Config(
        app = app,
        host = "0.0.0.0",
        port = 9090,
        loop = loop)
    # Use access_log to turn off logging

    config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
    # Use access_log to turn off logging

    server = uvicorn.Server(config)
    loop.run_until_complete(server.serve())
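For reference, the net effect of the reworked `custom_openapi` above is that every invocation schema gains an `output` key pointing at its output type's schema. A hedged sketch of the resulting fragment (the `TextToImageInvocation` and `ImageOutput` names are examples, and the schemas are abbreviated):

```python
# Illustrative sketch only: the shape of the patched OpenAPI document.
openapi_fragment = {
    "components": {
        "schemas": {
            # added from the pydantic output schemas
            "ImageOutput": {"title": "ImageOutput", "type": "object"},
            "TextToImageInvocation": {
                "title": "TextToImageInvocation",
                "type": "object",
                # added by custom_openapi so UIs can discover each node's output type
                "output": {"$ref": "#/components/schemas/ImageOutput"},
            },
        }
    }
}
```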
287  invokeai/app/cli/commands.py  Normal file
@@ -0,0 +1,287 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt

import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.invoker import Invoker


def add_field_argument(command_parser, name: str, field, default_override = None):
    default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
    if get_origin(field.type_) == Literal:
        allowed_values = get_args(field.type_)
        allowed_types = set()
        for val in allowed_values:
            allowed_types.add(type(val))
        allowed_types_list = list(allowed_types)
        field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]  # type: ignore

        command_parser.add_argument(
            f"--{name}",
            dest=name,
            type=field_type,
            default=default,
            choices=allowed_values,
            help=field.field_info.description,
        )
    else:
        command_parser.add_argument(
            f"--{name}",
            dest=name,
            type=field.type_,
            default=default,
            help=field.field_info.description,
        )


def add_parsers(
    subparsers,
    commands: list[type],
    command_field: str = "type",
    exclude_fields: list[str] = ["id", "type"],
    add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
    """Adds parsers for each command to the subparsers"""

    # Create subparsers for each command
    for command in commands:
        hints = get_type_hints(command)
        cmd_name = get_args(hints[command_field])[0]
        command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)

        if add_arguments is not None:
            add_arguments(command_parser)

        # Convert all fields to arguments
        fields = command.__fields__  # type: ignore
        for name, field in fields.items():
            if name in exclude_fields:
                continue

            add_field_argument(command_parser, name, field)


def add_graph_parsers(
    subparsers,
    graphs: list[LibraryGraph],
    add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
    for graph in graphs:
        command_parser = subparsers.add_parser(graph.name, help=graph.description)

        if add_arguments is not None:
            add_arguments(command_parser)

        # Add arguments for inputs
        for exposed_input in graph.exposed_inputs:
            node = graph.graph.get_node(exposed_input.node_path)
            field = node.__fields__[exposed_input.field]
            default_override = getattr(node, exposed_input.field)
            add_field_argument(command_parser, exposed_input.alias, field, default_override)


class CliContext:
    invoker: Invoker
    session: GraphExecutionState
    parser: argparse.ArgumentParser
    defaults: dict[str, Any]
    graph_nodes: dict[str, str]
    nodes_added: list[str]

    def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
        self.invoker = invoker
        self.session = session
        self.parser = parser
        self.defaults = dict()
        self.graph_nodes = dict()
        self.nodes_added = list()

    def get_session(self):
        self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
        return self.session

    def reset(self):
        self.session = self.invoker.create_execution_state()
        self.graph_nodes = dict()
        self.nodes_added = list()
        # Leave defaults unchanged

    def add_node(self, node: BaseInvocation):
        self.get_session()
        self.session.graph.add_node(node)
        self.nodes_added.append(node.id)
        self.invoker.services.graph_execution_manager.set(self.session)

    def add_edge(self, edge: Edge):
        self.get_session()
        self.session.add_edge(edge)
        self.invoker.services.graph_execution_manager.set(self.session)


class ExitCli(Exception):
    """Exception to exit the CLI"""
    pass


class BaseCommand(ABC, BaseModel):
    """A CLI command"""

    # All commands must include a type name like this:
    # type: Literal['your_command_name'] = 'your_command_name'

    @classmethod
    def get_all_subclasses(cls):
        subclasses = []
        toprocess = [cls]
        while len(toprocess) > 0:
            next = toprocess.pop(0)
            next_subclasses = next.__subclasses__()
            subclasses.extend(next_subclasses)
            toprocess.extend(next_subclasses)
        return subclasses

    @classmethod
    def get_commands(cls):
        return tuple(BaseCommand.get_all_subclasses())

    @classmethod
    def get_commands_map(cls):
        # Get the type strings out of the literals and into a dictionary
        return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t), BaseCommand.get_all_subclasses()))

    @abstractmethod
    def run(self, context: CliContext) -> None:
        """Run the command. Raise ExitCli to exit."""
        pass


class ExitCommand(BaseCommand):
    """Exits the CLI"""
    type: Literal['exit'] = 'exit'

    def run(self, context: CliContext) -> None:
        raise ExitCli()


class HelpCommand(BaseCommand):
    """Shows help"""
    type: Literal['help'] = 'help'

    def run(self, context: CliContext) -> None:
        context.parser.print_help()
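As the comment in `BaseCommand` indicates, a new command only needs a `type` literal, optional input fields, and a `run` method; subclassing alone registers it, since `get_all_subclasses` walks the class tree. A minimal sketch (this `EchoCommand` is illustrative, not part of the diff):

```python
class EchoCommand(BaseCommand):
    """Echoes a message (illustrative sketch, not part of the diff)"""
    type: Literal['echo'] = 'echo'

    # Inputs
    message: str = Field(default="", description="The message to echo")

    def run(self, context: CliContext) -> None:
        # Discovered automatically via BaseCommand.get_all_subclasses(),
        # so no explicit registration step is needed
        logger.info(self.message)
```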
def get_graph_execution_history(
    graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
    """Gets the history of fully-executed invocations for a graph execution"""
    return (
        n
        for n in reversed(graph_execution_state.executed_history)
        if n in graph_execution_state.graph.nodes
    )


def get_invocation_command(invocation) -> str:
    fields = invocation.__fields__.items()
    type_hints = get_type_hints(type(invocation))
    command = [invocation.type]
    for name, field in fields:
        if name in ["id", "type"]:
            continue

        # TODO: add links

        # Skip image fields when serializing command
        type_hint = type_hints.get(name) or None
        if type_hint is ImageField or ImageField in get_args(type_hint):
            continue

        field_value = getattr(invocation, name)
        field_default = field.default
        if field_value != field_default:
            if type_hint is str or str in get_args(type_hint):
                command.append(f'--{name} "{field_value}"')
            else:
                command.append(f"--{name} {field_value}")

    return " ".join(command)
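`get_invocation_command` serializes a node back into the command line that would recreate it: string fields are quoted, fields still at their defaults are omitted, and image fields are skipped. A self-contained sketch with a toy pydantic model (`ToyInvocation` is hypothetical, used only to show the round trip):

```python
from typing import Literal
from pydantic import BaseModel, Field

class ToyInvocation(BaseModel):
    """Hypothetical stand-in for an invocation (sketch only)"""
    id: str = Field(default="0")
    type: Literal["toy"] = "toy"
    prompt: str = Field(default="", description="Prompt")
    steps: int = Field(default=30, description="Steps")

node = ToyInvocation(id="1", prompt="a cat", steps=20)
print(get_invocation_command(node))  # -> toy --prompt "a cat" --steps 20
```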
class HistoryCommand(BaseCommand):
    """Shows the invocation history"""
    type: Literal['history'] = 'history'

    # Inputs
    # fmt: off
    count: int = Field(default=5, gt=0, description="The number of history entries to show")
    # fmt: on

    def run(self, context: CliContext) -> None:
        history = list(get_graph_execution_history(context.get_session()))
        for i in range(min(self.count, len(history))):
            entry_id = history[-1 - i]
            entry = context.get_session().graph.get_node(entry_id)
            logger.info(f"{entry_id}: {get_invocation_command(entry)}")


class SetDefaultCommand(BaseCommand):
    """Sets a default value for a field"""
    type: Literal['default'] = 'default'

    # Inputs
    # fmt: off
    field: str = Field(description="The field to set the default for")
    value: str = Field(description="The value to set the default to, or None to clear the default")
    # fmt: on

    def run(self, context: CliContext) -> None:
        if self.value is None:
            if self.field in context.defaults:
                del context.defaults[self.field]
        else:
            context.defaults[self.field] = self.value


class DrawGraphCommand(BaseCommand):
    """Debugs a graph"""
    type: Literal['draw_graph'] = 'draw_graph'

    def run(self, context: CliContext) -> None:
        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
        nxgraph = session.graph.nx_graph_flat()

        # Draw the networkx graph
        plt.figure(figsize=(20, 20))
        pos = nx.spectral_layout(nxgraph)
        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
        nx.draw_networkx_edges(nxgraph, pos, width=2)
        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()


class DrawExecutionGraphCommand(BaseCommand):
    """Debugs an execution graph"""
    type: Literal['draw_xgraph'] = 'draw_xgraph'

    def run(self, context: CliContext) -> None:
        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
        nxgraph = session.execution_graph.nx_graph_flat()

        # Draw the networkx graph
        plt.figure(figsize=(20, 20))
        pos = nx.spectral_layout(nxgraph)
        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
        nx.draw_networkx_edges(nxgraph, pos, width=2)
        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
        plt.axis("off")
        plt.show()
168  invokeai/app/cli/completer.py  Normal file
@@ -0,0 +1,168 @@
"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex

from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin

import invokeai.backend.util.logging as logger
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand

# singleton object, class variable
completer = None


class Completer(object):

    def __init__(self, model_manager: ModelManager):
        self.commands = self.get_commands()
        self.matches = None
        self.linebuffer = None
        self.manager = model_manager
        return

    def complete(self, text, state):
        """
        Complete commands and switches from the node CLI command line.
        Switches are determined in a context-specific manner.
        """

        buffer = readline.get_line_buffer()
        if state == 0:
            options = None
            try:
                current_command, current_switch = self.get_current_command(buffer)
                options = self.get_command_options(current_command, current_switch)
            except IndexError:
                pass
            options = options or list(self.parse_commands().keys())

            if not text:  # first time
                self.matches = options
            else:
                self.matches = [s for s in options if s and s.startswith(text)]

        try:
            match = self.matches[state]
        except IndexError:
            match = None
        return match

    @classmethod
    def get_commands(self) -> List[object]:
        """
        Return a list of all the client commands and invocations.
        """
        return BaseCommand.get_commands() + BaseInvocation.get_invocations()

    def get_current_command(self, buffer: str) -> tuple[str, str]:
        """
        Parse the readline buffer to find the most recent command and its switch.
        """
        if len(buffer) == 0:
            return None, None
        tokens = shlex.split(buffer)
        command = None
        switch = None
        for t in tokens:
            if t[0].isalpha():
                if switch is None:
                    command = t
            else:
                switch = t
        # don't try to autocomplete switches that are already complete
        if switch and buffer.endswith(' '):
            switch = None
        return command or '', switch or ''

    def parse_commands(self) -> Dict[str, List[str]]:
        """
        Return a dict in which the keys are the command name
        and the values are the parameters the command takes.
        """
        result = dict()
        for command in self.commands:
            hints = get_type_hints(command)
            name = get_args(hints['type'])[0]
            result.update({name: hints})
        return result

    def get_command_options(self, command: str, switch: str) -> List[str]:
        """
        Return all the parameters that can be passed to the command as
        command-line switches. Returns None if the command is unrecognized.
        """
        parsed_commands = self.parse_commands()
        if command not in parsed_commands:
            return None

        # handle switches in the format "-foo=bar"
        argument = None
        if switch and '=' in switch:
            switch, argument = switch.split('=')

        parameter = switch.strip('-')
        if parameter in parsed_commands[command]:
            if argument is None:
                return self.get_parameter_options(parameter, parsed_commands[command][parameter])
            else:
                return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
        else:
            return [f"--{x}" for x in parsed_commands[command].keys()]

    def get_parameter_options(self, parameter: str, typehint) -> List[str]:
        """
        Given a parameter type (such as Literal), offers autocompletions.
        """
        if get_origin(typehint) == Literal:
            return get_args(typehint)
        if parameter == 'model':
            return self.manager.model_names()

    def _pre_input_hook(self):
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None


def set_autocompleter(model_manager: ModelManager) -> Completer:
    global completer

    if completer:
        return completer

    completer = Completer(model_manager)

    readline.set_completer(completer.complete)
    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(True)
    except:
        pass
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(" ")
    readline.parse_and_bind("tab: complete")
    readline.parse_and_bind("set print-completions-horizontally off")
    readline.parse_and_bind("set page-completions on")
    readline.parse_and_bind("set skip-completed-text on")
    readline.parse_and_bind("set show-all-if-ambiguous on")

    histfile = Path(Globals.root, ".invoke_history")
    try:
        readline.read_history_file(histfile)
        readline.set_history_length(1000)
    except FileNotFoundError:
        pass
    except OSError:  # file likely corrupted
        newname = f"{histfile}.old"
        logger.error(
            f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
        )
        histfile.replace(Path(newname))
    atexit.register(readline.write_history_file, histfile)
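Since `complete` derives its candidates from `parse_commands()`, its matching behavior can be exercised outside an interactive readline session too. A hedged sketch (assumes a `ModelManager` instance named `model_manager`; with an empty line buffer, only command-name completion applies):

```python
completer = Completer(model_manager)  # model_manager is assumed to exist
print(completer.complete("hi", 0))    # -> "history" (first match)
print(completer.complete("hi", 1))    # -> None (no further matches)
```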
386  invokeai/app/cli_app.py  Normal file
@@ -0,0 +1,386 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import argparse
import os
import re
import shlex
import time
from typing import (
    Union,
    get_type_hints,
)

from pydantic import BaseModel
from pydantic.fields import Field


import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage

class CliCommand(BaseModel):
    command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type")  # type: ignore


class InvalidArgs(Exception):
    pass


def add_invocation_args(command_parser):
    # Add linking capability
    command_parser.add_argument(
        "--link",
        "-l",
        action="append",
        nargs=3,
        help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
    )

    command_parser.add_argument(
        "--link_node",
        "-ln",
        action="append",
        help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
    )


def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
    # Create invocation parser
    parser = argparse.ArgumentParser()

    def exit(*args, **kwargs):
        raise InvalidArgs

    parser.exit = exit
    subparsers = parser.add_subparsers(dest="type")

    # Create subparsers for each invocation
    invocations = BaseInvocation.get_all_subclasses()
    add_parsers(subparsers, invocations, add_arguments=add_invocation_args)

    # Create subparsers for each command
    commands = BaseCommand.get_all_subclasses()
    add_parsers(subparsers, commands, exclude_fields=["type"])

    # Create subparsers for exposed CLI graphs
    # TODO: add a way to identify these graphs
    text_to_image = services.graph_library.get(default_text_to_image_graph_id)
    add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)

    return parser


class NodeField():
    alias: str
    node_path: str
    field: str
    field_type: type

    def __init__(self, alias: str, node_path: str, field: str, field_type: type):
        self.alias = alias
        self.node_path = node_path
        self.field = field
        self.field_type = field_type


def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
    return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}


def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
    """Gets the node field for the specified field alias"""
    exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
    node_type = type(graph.graph.get_node(exposed_input.node_path))
    return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])


def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
    """Gets the node field for the specified field alias"""
    exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
    node_type = type(graph.graph.get_node(exposed_output.node_path))
    node_output_type = node_type.get_output_type()
    return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])


def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
    """Gets the inputs for the specified invocation from the context"""
    node_type = type(invocation)
    if node_type is not GraphInvocation:
        return fields_from_type_hints(get_type_hints(node_type), invocation.id)
    else:
        graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
        return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}


def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
    """Gets the outputs for the specified invocation from the context"""
    node_type = type(invocation)
    if node_type is not GraphInvocation:
        return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
    else:
        graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
        return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}


def generate_matching_edges(
    a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
    """Generates all possible edges between two invocations"""
    afields = get_node_outputs(a, context)
    bfields = get_node_inputs(b, context)

    matching_fields = set(afields.keys()).intersection(bfields.keys())

    # Remove invalid fields
    invalid_fields = set(["type", "id"])
    matching_fields = matching_fields.difference(invalid_fields)

    # Validate types
    matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]

    edges = [
        Edge(
            source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
            destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
        )
        for alias in matching_fields
    ]
    return edges
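`generate_matching_edges` is what makes the `|` pipe syntax in the CLI work: every output field of the upstream node whose name and type line up with an input field of the downstream node becomes an edge. For two image nodes, the result is roughly the following (a sketch; node ids are illustrative):

```python
# Sketch: piping txt2img into img2img would produce approximately this edge,
# since the upstream ImageOutput and the downstream node share an "image" field.
edge = Edge(
    source=EdgeConnection(node_id="0", field="image"),
    destination=EdgeConnection(node_id="1", field="image"),
)
```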
class SessionError(Exception):
    """Raised when a session error has occurred"""
    pass


def invoke_all(context: CliContext):
    """Runs all invocations in the specified session"""
    context.invoker.invoke(context.session, invoke_all=True)
    while not context.get_session().is_complete():
        # Wait some time
        time.sleep(0.1)

    # Print any errors
    if context.session.has_error():
        for n in context.session.errors:
            context.invoker.services.logger.error(
                f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
            )

        raise SessionError()


def invoke_cli():
    config = Args()
    config.parse_args()
    model_manager = get_model_manager(config, logger=logger)

    # This initializes the autocompleter and returns it.
    # Currently nothing is done with the returned Completer
    # object, but the object can be used to change autocompletion
    # behavior on the fly, if desired.
    set_autocompleter(model_manager)

    events = EventServiceBase()

    metadata = PngMetadataService()

    output_folder = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../outputs")
    )

    # TODO: build a file/path manager?
    db_location = os.path.join(output_folder, "invokeai.db")

    services = InvocationServices(
        model_manager=model_manager,
        events=events,
        latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
        images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
        metadata=metadata,
        queue=MemoryInvocationQueue(),
        graph_library=SqliteItemStorage[LibraryGraph](
            filename=db_location, table_name="graphs"
        ),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
        ),
        processor=DefaultInvocationProcessor(),
        restoration=RestorationServices(config, logger=logger),
        logger=logger,
    )

    system_graphs = create_system_graphs(services.graph_library)
    system_graph_names = set([g.name for g in system_graphs])

    invoker = Invoker(services)
    session: GraphExecutionState = invoker.create_execution_state()
    parser = get_command_parser(services)

    re_negid = re.compile('^-[0-9]+$')

    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

    context = CliContext(invoker, session, parser)

    while True:
        try:
            cmd_input = input("invoke> ")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-c exits
            break

        try:
            # Refresh the state of the session
            # history = list(get_graph_execution_history(context.session))
            history = list(reversed(context.nodes_added))

            # Split the command for piping
            cmds = cmd_input.split("|")
            start_id = len(context.nodes_added)
            current_id = start_id
            new_invocations = list()
            for cmd in cmds:
                if cmd is None or cmd.strip() == "":
                    raise InvalidArgs("Empty command")

                # Parse args to create invocation
                args = vars(context.parser.parse_args(shlex.split(cmd.strip())))

                # Override defaults
                for field_name, field_default in context.defaults.items():
                    if field_name in args:
                        args[field_name] = field_default

                # Parse invocation
                command: CliCommand = None  # type:ignore
                system_graph: LibraryGraph|None = None
                if args['type'] in system_graph_names:
                    system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
                    invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
                    for exposed_input in system_graph.exposed_inputs:
                        if exposed_input.alias in args:
                            node = invocation.graph.get_node(exposed_input.node_path)
                            field = exposed_input.field
                            setattr(node, field, args[exposed_input.alias])
                    command = CliCommand(command=invocation)
                    context.graph_nodes[invocation.id] = system_graph.id
                else:
                    args["id"] = current_id
                    command = CliCommand(command=args)

                if command is None:
                    continue

                # Run any CLI commands immediately
                if isinstance(command.command, BaseCommand):
                    # Invoke all current nodes to preserve operation order
                    invoke_all(context)

                    # Run the command
                    command.command.run(context)
                    continue

                # TODO: handle linking with library graphs
                # Pipe previous command output (if there was a previous command)
                edges: list[Edge] = list()
                if len(history) > 0 or current_id != start_id:
                    from_id = (
                        history[0] if current_id == start_id else str(current_id - 1)
                    )
                    from_node = (
                        next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
                        if current_id != start_id
                        else context.session.graph.get_node(from_id)
                    )
                    matching_edges = generate_matching_edges(
                        from_node, command.command, context
                    )
                    edges.extend(matching_edges)

                # Parse provided links
                if "link_node" in args and args["link_node"]:
                    for link in args["link_node"]:
                        node_id = link
                        if re_negid.match(node_id):
                            node_id = str(current_id + int(node_id))

                        link_node = context.session.graph.get_node(node_id)
                        matching_edges = generate_matching_edges(
                            link_node, command.command, context
                        )
                        matching_destinations = [e.destination for e in matching_edges]
                        edges = [e for e in edges if e.destination not in matching_destinations]
                        edges.extend(matching_edges)

                if "link" in args and args["link"]:
                    for link in args["link"]:
                        edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]

                        node_id = link[0]
                        if re_negid.match(node_id):
                            node_id = str(current_id + int(node_id))

                        # TODO: handle missing input/output
                        node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
                        node_input = get_node_inputs(command.command, context)[link[2]]

                        edges.append(
                            Edge(
                                source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
                                destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
                            )
                        )

                new_invocations.append((command.command, edges))

                current_id = current_id + 1

                # Add the node to the session
                context.add_node(command.command)
                for edge in edges:
                    print(edge)
                    context.add_edge(edge)

            # Execute all remaining nodes
            invoke_all(context)

        except InvalidArgs:
            invoker.services.logger.warning('Invalid command, use "help" to list commands')
            continue

        except SessionError:
            # Start a new session
            invoker.services.logger.warning("Session error: creating a new session")
            context.reset()

        except ExitCli:
            break

        except SystemExit:
            continue

    invoker.stop()


if __name__ == "__main__":
    invoke_cli()
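Put together, an interactive session with this CLI might look like the following (an illustrative transcript; prompts and values are examples):

```
invoke> txt2img --prompt "a photo of a forest" --steps 20 | img2img --prompt "an oil painting of a forest" --strength 0.6
invoke> history
invoke> default steps 50
invoke> exit
```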
@@ -4,5 +4,9 @@ __all__ = []

dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname):
    if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
    if (
        f != "__init__.py"
        and os.path.isfile("%s/%s" % (dirname, f))
        and f[-3:] == ".py"
    ):
        __all__.append(f[:-3])
@@ -1,10 +1,15 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from __future__ import annotations

from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING

from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices

if TYPE_CHECKING:
    from ..services.invocation_services import InvocationServices


class InvocationContext:
@@ -70,5 +75,60 @@ class BaseInvocation(ABC, BaseModel):
    def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
        """Invoke with provided context and return outputs."""
        pass

    #fmt: off
    id: str = Field(description="The id of this node. Must be unique among all nodes.")
    #fmt: on


# TODO: figure out a better way to provide these hints
# TODO: when we can upgrade to python 3.11, we can use the `NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
    type_hints: Dict[
        str,
        Literal[
            "integer",
            "float",
            "boolean",
            "string",
            "enum",
            "image",
            "latents",
            "model",
        ],
    ]
    tags: List[str]
    title: str


class CustomisedSchemaExtra(TypedDict):
    ui: UIConfig


class InvocationConfig(BaseModel.Config):
    """Customizes pydantic's BaseModel.Config class for use by Invocations.

    Provide `schema_extra` a `ui` dict to add hints for generated UIs.

    `tags`
    - A list of strings, used to categorise invocations.

    `type_hints`
    - A dict of field types which override the types in the invocation definition.
    - Each key should be the name of one of the invocation's fields.
    - Each value should be one of the valid types:
      - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`

    ```python
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["stable-diffusion", "image"],
                "type_hints": {
                    "initial_image": "image",
                },
            },
        }
    ```
    """

    schema_extra: CustomisedSchemaExtra
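Putting the pieces together, a complete node consists of a `type` literal, input fields, an output class, and optionally an `InvocationConfig` for UI hints. A minimal sketch (`UpcaseInvocation` and `StringOutput` are illustrative, not part of the diff):

```python
from typing import Literal
from pydantic import Field

class StringOutput(BaseInvocationOutput):
    """Illustrative output type (not part of the diff)"""
    type: Literal["string_output"] = "string_output"
    text: str = Field(default="", description="The output string")

class UpcaseInvocation(BaseInvocation):
    """Upper-cases a string (illustrative sketch, not part of the diff)"""
    type: Literal["upcase"] = "upcase"
    text: str = Field(default="", description="The string to upper-case")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["string"], "title": "Upcase"},
        }

    def invoke(self, context: InvocationContext) -> StringOutput:
        return StringOutput(text=self.text.upper())
```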
64  invokeai/app/invocations/collections.py  Normal file
@@ -0,0 +1,64 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional

import numpy as np
from pydantic import Field

from invokeai.app.util.misc import SEED_MAX, get_random_seed

from .baseinvocation import (
    BaseInvocation,
    InvocationContext,
    BaseInvocationOutput,
)


class IntCollectionOutput(BaseInvocationOutput):
    """A collection of integers"""

    type: Literal["int_collection"] = "int_collection"

    # Outputs
    collection: list[int] = Field(default=[], description="The int collection")


class RangeInvocation(BaseInvocation):
    """Creates a range"""

    type: Literal["range"] = "range"

    # Inputs
    start: int = Field(default=0, description="The start of the range")
    stop: int = Field(default=10, description="The stop of the range")
    step: int = Field(default=1, description="The step of the range")

    def invoke(self, context: InvocationContext) -> IntCollectionOutput:
        return IntCollectionOutput(
            collection=list(range(self.start, self.stop, self.step))
        )


class RandomRangeInvocation(BaseInvocation):
    """Creates a collection of random numbers"""

    type: Literal["random_range"] = "random_range"

    # Inputs
    low: int = Field(default=0, description="The inclusive low value")
    high: int = Field(
        default=np.iinfo(np.int32).max, description="The exclusive high value"
    )
    size: int = Field(default=1, description="The number of values to generate")
    seed: int = Field(
        ge=0,
        le=SEED_MAX,
        description="The seed for the RNG (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> IntCollectionOutput:
        rng = np.random.default_rng(self.seed)
        return IntCollectionOutput(
            collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
        )
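Neither node above touches the invocation context, so their behavior can be sketched directly (passing `None` for the context is for illustration only):

```python
out = RangeInvocation(id="1", start=0, stop=10, step=2).invoke(None)  # type: ignore
print(out.collection)  # -> [0, 2, 4, 6, 8]

rand = RandomRangeInvocation(id="2", low=0, high=10, size=3, seed=42).invoke(None)  # type: ignore
print(rand.collection)  # three ints in [0, 10), deterministic for a fixed seed
```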
246  invokeai/app/invocations/compel.py  Normal file
@@ -0,0 +1,246 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field

from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig

from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager

from compel import Compel
from compel.prompt_parser import (
    Blend,
    CrossAttentionControlSubstitute,
    FlattenedPrompt,
    Fragment,
)

from invokeai.backend.globals import Globals


class ConditioningField(BaseModel):
    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

    class Config:
        schema_extra = {"required": ["conditioning_name"]}


class CompelOutput(BaseInvocationOutput):
    """Compel parser output"""

    #fmt: off
    type: Literal["compel_output"] = "compel_output"

    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    #fmt: on


class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }
    def invoke(self, context: InvocationContext) -> CompelOutput:

        # TODO: load without model
        model = choose_model(context.services.model_manager, self.model)
        pipeline = model["model"]
        tokenizer = pipeline.tokenizer
        text_encoder = pipeline.text_encoder

        # TODO: global? input?
        # use_full_precision = precision == "float32" or precision == "autocast"
        # use_full_precision = False

        # TODO: redo TI when separate model loading is implemented
        # textual_inversion_manager = TextualInversionManager(
        #     tokenizer=tokenizer,
        #     text_encoder=text_encoder,
        #     full_precision=use_full_precision,
        # )

        def load_huggingface_concepts(concepts: list[str]):
            pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

        # apply the concepts library to the prompt
        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.prompt,
            lambda concepts: load_huggingface_concepts(concepts),
            pipeline.textual_inversion_manager.get_all_trigger_strings(),
        )

        # lazy-load any deferred textual inversions.
        # this might take a couple of seconds the first time a textual inversion is used.
        pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
            prompt_str
        )

        compel = Compel(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            textual_inversion_manager=pipeline.textual_inversion_manager,
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=True,  # TODO:
        )

        # TODO: support legacy blend?

        conjunction = Compel.parse_prompt_string(prompt_str)
        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]

        if getattr(Globals, "log_tokenization", False):
            log_tokenization_for_prompt_object(prompt, tokenizer)

        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

        # TODO: long prompt support
        # if not self.truncate_long_prompts:
        #     [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
            cross_attention_control_args=options.get("cross_attention_control", None),
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.save(conditioning_name, (c, ec))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )
def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
        return max(
            [
                get_max_token_count(tokenizer, c, truncate_if_too_long)
                for c in blend.prompts
            ]
        )
    else:
        return len(
            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
        )


def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
    if type(parsed_prompt) is Blend:
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    text_fragments = [
        x.text
        if type(x) is Fragment
        else (
            " ".join([f.text for f in x.original])
            if type(x) is CrossAttentionControlSubstitute
            else str(x)
        )
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[0:max_tokens_length]
    return tokens


def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
        for i, c in enumerate(blend.prompts):
            log_tokenization_for_prompt_object(
                c,
                tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
            )
    elif type(p) is FlattenedPrompt:
        flattened_prompt: FlattenedPrompt = p
        if flattened_prompt.wants_cross_attention_control:
            original_fragments = []
            edited_fragments = []
            for f in flattened_prompt.children:
                if type(f) is CrossAttentionControlSubstitute:
                    original_fragments += f.original
                    edited_fragments += f.edited
                else:
                    original_fragments.append(f)
                    edited_fragments.append(f)

            original_text = " ".join([x.text for x in original_fragments])
            log_tokenization_for_text(
                original_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap originals)",
            )
            edited_text = " ".join([x.text for x in edited_fragments])
            log_tokenization_for_text(
                edited_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap replacements)",
            )
        else:
            text = " ".join([x.text for x in flattened_prompt.children])
            log_tokenization_for_text(
                text, tokenizer, display_label=display_label_prefix
            )


def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
    """shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    """
    tokens = tokenizer.tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)

    for i in range(0, totalTokens):
        token = tokens[i].replace("</w>", " ")
        # alternate color
        s = (usedTokens % 6) + 1
        if truncate_if_too_long and i >= tokenizer.model_max_length:
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
        else:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1

    if usedTokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
        print(f"{tokenized}\x1b[0m")

    if discarded != "":
        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
        print(f"{discarded}\x1b[0m")
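The helpers above need only `tokenize()` and `model_max_length` from the tokenizer, so the logging and truncation arithmetic can be checked with a stub (a standalone sketch; in real use the pipeline's CLIP tokenizer is passed, where `model_max_length` is 77 and the `- 2` reserves the BOS/EOS positions, hence the "typically 75"):

```python
class StubTokenizer:
    """Whitespace tokenizer standing in for the CLIP tokenizer (sketch only)."""
    model_max_length = 77

    def tokenize(self, text):
        return text.split()

log_tokenization_for_text("a photo of a forest", StubTokenizer(), display_label="(example)")
# prints the five colorized tokens; with truncate_if_too_long=True, tokens at
# index >= model_max_length would be reported in the "Discarded" section
```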
69  invokeai/app/invocations/cv.py  Normal file
@@ -0,0 +1,69 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal

import cv2 as cv
import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field

from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output


class CvInvocationConfig(BaseModel):
    """Helper class to provide all OpenCV invocations with additional config"""

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["cv", "image"],
            },
        }


class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
    """Simple inpaint using opencv."""
    #fmt: off
    type: Literal["cv_inpaint"] = "cv_inpaint"

    # Inputs
    image: ImageField = Field(default=None, description="The image to inpaint")
    mask: ImageField = Field(default=None, description="The mask to use when inpainting")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        mask = context.services.images.get(self.mask.image_type, self.mask.image_name)

        # Convert to cv image/mask
        # TODO: consider making these utility functions
        cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
        cv_mask = numpy.array(ImageOps.invert(mask))

        # Inpaint
        cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)

        # Convert back to Pillow
        # TODO: consider making a utility function
        image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, image_inpainted, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image_inpainted,
        )
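`cv.inpaint` expects an 8-bit BGR image and a single-channel mask whose non-zero pixels mark the region to fill, which is why the PIL mask is inverted before conversion above. A standalone sketch of the call on synthetic data (not taken from the diff):

```python
import cv2 as cv
import numpy as np

# Synthetic 64x64 image with a damaged square to repair (sketch only)
img = np.full((64, 64, 3), 200, dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[24:40, 24:40] = 255  # non-zero pixels = region to inpaint
img[24:40, 24:40] = 0     # "damage" the same region

restored = cv.inpaint(img, mask, 3, cv.INPAINT_TELEA)  # radius 3, Telea method
```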
312  invokeai/app/invocations/generate.py  Normal file
@@ -0,0 +1,312 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from functools import partial
from typing import Literal, Optional, Union, get_args

import numpy as np
from torch import Tensor

from pydantic import BaseModel, Field

from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback

SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'


class SDImageInvocation(BaseModel):
    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["stable-diffusion", "image"],
                "type_hints": {
                    "model": "model",
                },
            },
        }

# Text to image
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
    """Generates an image using text2img."""

    type: Literal["txt2img"] = "txt2img"

    # Inputs
    # TODO: consider making prompt optional to enable providing prompt through a link
    # fmt: off
    prompt: Optional[str] = Field(description="The prompt to generate an image from")
    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
    model: str = Field(default="", description="The model to use (currently ignored)")
    # fmt: on

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self,
        context: InvocationContext,
        source_node_id: str,
        intermediate_state: PipelineIntermediateState,
    ) -> None:
        stable_diffusion_step_callback(
            context=context,
            intermediate_state=intermediate_state,
            node=self.dict(),
            source_node_id=source_node_id,
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Handle invalid model parameter
        model = choose_model(context.services.model_manager, self.model)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
        )
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        outputs = Txt2Img(model).generate(
            prompt=self.prompt,
            step_callback=partial(self.dispatch_progress, context, source_node_id),
            **self.dict(
                exclude={"prompt"}
            ),  # Shorthand for passing all of the parameters above manually
        )
        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generate_output = next(outputs)

        image_dto = context.services.images_new.create(
            image=generate_output.image,
            image_type=ImageType.RESULT,
            image_category=ImageCategory.IMAGE,
            session_id=context.graph_execution_state_id,
            node_id=self.id,
        )

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        # image_type = ImageType.RESULT
        # image_name = context.services.images.create_name(
        #     context.graph_execution_state_id, self.id
        # )

        # metadata = context.services.metadata.build_metadata(
        #     session_id=context.graph_execution_state_id, node=self
        # )

        # context.services.images.save(
        #     image_type, image_name, generate_output.image, metadata
        # )

        # context.services.images_db.set(
        #     id=image_name,
        #     image_type=ImageType.RESULT,
        #     image_category=ImageCategory.IMAGE,
        #     session_id=context.graph_execution_state_id,
        #     node_id=self.id,
        #     metadata=GeneratedImageOrLatentsMetadata(),
        # )

        return build_image_output(
            image_type=image_dto.image_type,
            image_name=image_dto.image_name,
            image=generate_output.image,
        )

class ImageToImageInvocation(TextToImageInvocation):
|
||||
"""Generates an image using img2img."""
|
||||
|
||||
type: Literal["img2img"] = "img2img"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The strength of the original image"
|
||||
)
|
||||
fit: bool = Field(
|
||||
default=True,
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
# TODO: can this return multiple results? Should it?
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image,
|
||||
)


class InpaintInvocation(ImageToImageInvocation):
    """Generates an image using inpaint."""

    type: Literal["inpaint"] = "inpaint"

    # Inputs
    mask: Union[ImageField, None] = Field(description="The mask")
    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
    seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
    seam_strength: float = Field(
        default=0.75, gt=0, le=1, description="The seam inpaint strength"
    )
    seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
    tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
    infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions")
    inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
    inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
    inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
    inpaint_replace: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="The amount by which to replace masked areas with latent noise",
    )

    def dispatch_progress(
        self,
        context: InvocationContext,
        source_node_id: str,
        intermediate_state: PipelineIntermediateState,
    ) -> None:
        stable_diffusion_step_callback(
            context=context,
            intermediate_state=intermediate_state,
            node=self.dict(),
            source_node_id=source_node_id,
        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = (
            None
            if self.image is None
            else context.services.images.get(
                self.image.image_type, self.image.image_name
            )
        )
        mask = (
            None
            if self.mask is None
            else context.services.images.get(self.mask.image_type, self.mask.image_name)
        )

        # Handle invalid model parameter
        model = choose_model(context.services.model_manager, self.model)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
        )
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        outputs = Inpaint(model).generate(
            prompt=self.prompt,
            init_image=image,
            mask_image=mask,
            step_callback=partial(self.dispatch_progress, context, source_node_id),
            **self.dict(
                exclude={"prompt", "image", "mask"}
            ),  # Shorthand for passing all of the parameters above manually
        )

        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
        # each time it is called. We only need the first one.
        generator_output = next(outputs)

        result_image = generator_output.image

        # Results are image and seed, unwrap for now and ignore the seed
        # TODO: pre-seed?
        # TODO: can this return multiple results? Should it?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, result_image, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=result_image,
        )


invokeai/app/invocations/image.py (new file, 367 lines)
@@ -0,0 +1,367 @@

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import io
from typing import Literal, Optional

import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field

from ..models.image import ImageField, ImageType
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)


class PILInvocationConfig(BaseModel):
    """Helper class to provide all PIL invocations with additional config"""

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["PIL", "image"],
            },
        }


class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""

    # fmt: off
    type: Literal["image_output"] = "image_output"
    image: ImageField = Field(default=None, description="The output image")
    width: int = Field(description="The width of the image in pixels")
    height: int = Field(description="The height of the image in pixels")
    # fmt: on

    class Config:
        schema_extra = {"required": ["type", "image", "width", "height"]}


def build_image_output(
    image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
    """Builds an ImageOutput and its ImageField"""
    image_field = ImageField(
        image_name=image_name,
        image_type=image_type,
    )
    return ImageOutput(
        image=image_field,
        width=image.width,
        height=image.height,
    )


class MaskOutput(BaseInvocationOutput):
    """Base class for invocations that output a mask"""

    # fmt: off
    type: Literal["mask"] = "mask"
    mask: ImageField = Field(default=None, description="The output mask")
    # fmt: on

    class Config:
        schema_extra = {
            "required": [
                "type",
                "mask",
            ]
        }


class LoadImageInvocation(BaseInvocation):
    """Load an image and provide it as output."""

    # fmt: off
    type: Literal["load_image"] = "load_image"

    # Inputs
    image_type: ImageType = Field(description="The type of the image")
    image_name: str = Field(description="The name of the image")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(self.image_type, self.image_name)

        return build_image_output(
            image_type=self.image_type,
            image_name=self.image_name,
            image=image,
        )


class ShowImageInvocation(BaseInvocation):
    """Displays a provided image, and passes it forward in the pipeline."""

    type: Literal["show_image"] = "show_image"

    # Inputs
    image: ImageField = Field(default=None, description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        if image:
            image.show()

        # TODO: how to handle failure?

        return build_image_output(
            image_type=self.image.image_type,
            image_name=self.image.image_name,
            image=image,
        )


class CropImageInvocation(BaseInvocation, PILInvocationConfig):
    """Crops an image to a specified box. The box can be outside of the image."""

    # fmt: off
    type: Literal["crop"] = "crop"

    # Inputs
    image: ImageField = Field(default=None, description="The image to crop")
    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
    height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_crop = Image.new(
            mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
        )
        image_crop.paste(image, (-self.x, -self.y))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, image_crop, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image_crop,
        )
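
# The crop above works by pasting the source at a negative offset into a fresh
# transparent canvas, which is why the crop box may extend past the image
# bounds: e.g. with x=-10, y=0 and a 512x512 crop, paste(image, (10, 0)) simply
# leaves a 10px transparent strip on the left instead of raising an error.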


class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
    """Pastes an image into another image."""

    # fmt: off
    type: Literal["paste"] = "paste"

    # Inputs
    base_image: ImageField = Field(default=None, description="The base image")
    image: ImageField = Field(default=None, description="The image to paste")
    mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
    x: int = Field(default=0, description="The left x coordinate at which to paste the image")
    y: int = Field(default=0, description="The top y coordinate at which to paste the image")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        base_image = context.services.images.get(
            self.base_image.image_type, self.base_image.image_name
        )
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        mask = (
            None
            if self.mask is None
            else ImageOps.invert(
                context.services.images.get(self.mask.image_type, self.mask.image_name)
            )
        )
        # TODO: probably shouldn't invert mask here... should user be required to do it?

        min_x = min(0, self.x)
        min_y = min(0, self.y)
        max_x = max(base_image.width, image.width + self.x)
        max_y = max(base_image.height, image.height + self.y)

        new_image = Image.new(
            mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
        )
        new_image.paste(base_image, (abs(min_x), abs(min_y)))
        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, new_image, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=new_image,
        )
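
# Worked example of the canvas math above: base 512x512, pasted image 256x256
# at (x, y) = (-64, 600):
#     min_x = min(0, -64) = -64    max_x = max(512, 256 - 64)  = 512
#     min_y = min(0, 600) = 0      max_y = max(512, 256 + 600) = 856
# giving a (512 - -64) x (856 - 0) = 576x856 canvas; the base lands at
# (abs(-64), 0) = (64, 0) and the pasted image at (max(0, -64), max(0, 600))
# = (0, 600).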


class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
    """Extracts the alpha channel of an image as a mask."""

    # fmt: off
    type: Literal["tomask"] = "tomask"

    # Inputs
    image: ImageField = Field(default=None, description="The image to create the mask from")
    invert: bool = Field(default=False, description="Whether or not to invert the mask")
    # fmt: on

    def invoke(self, context: InvocationContext) -> MaskOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_mask = image.split()[-1]
        if self.invert:
            image_mask = ImageOps.invert(image_mask)

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, image_mask, metadata)
        return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))


class BlurInvocation(BaseInvocation, PILInvocationConfig):
    """Blurs an image"""

    # fmt: off
    type: Literal["blur"] = "blur"

    # Inputs
    image: ImageField = Field(default=None, description="The image to blur")
    radius: float = Field(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        blur = (
            ImageFilter.GaussianBlur(self.radius)
            if self.blur_type == "gaussian"
            else ImageFilter.BoxBlur(self.radius)
        )
        blur_image = image.filter(blur)

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, blur_image, metadata)
        return build_image_output(
            image_type=image_type, image_name=image_name, image=blur_image
        )


class LerpInvocation(BaseInvocation, PILInvocationConfig):
    """Linear interpolation of all pixels of an image"""

    # fmt: off
    type: Literal["lerp"] = "lerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
        image_arr = image_arr * (self.max - self.min) + self.min

        lerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, lerp_image, metadata)
        return build_image_output(
            image_type=image_type, image_name=image_name, image=lerp_image
        )
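
# The lerp above maps normalized pixel values into [min, max]:
#     out = (v / 255) * (max - min) + min
# e.g. with min=64, max=192: v=0 -> 64, v=255 -> 192, v=128 -> ~128.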


class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
    """Inverse linear interpolation of all pixels of an image"""

    # fmt: off
    type: Literal["ilerp"] = "ilerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_arr = numpy.asarray(image, dtype=numpy.float32)
        image_arr = (
            numpy.minimum(
                numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
            )
            * 255
        )

        ilerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, ilerp_image, metadata)
        return build_image_output(
            image_type=image_type, image_name=image_name, image=ilerp_image
        )
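
# ilerp is the inverse mapping: input values in [min, max] are stretched back
# to [0, 255] and clamped at both ends:
#     out = clip((v - min) / (max - min), 0, 1) * 255
# e.g. with min=64, max=192: v<=64 -> 0, v=128 -> 127, v>=192 -> 255.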


invokeai/app/invocations/infill.py (new file, 233 lines)
@@ -0,0 +1,233 @@

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional, Union, get_args

import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field

from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch

from ..models.image import ColorField, ImageField, ImageType
from .baseinvocation import (
    BaseInvocation,
    InvocationContext,
)


def infill_methods() -> list[str]:
    methods = [
        "tile",
        "solid",
    ]
    if PatchMatch.patchmatch_available():
        methods.insert(0, "patchmatch")
    return methods


INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
    "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
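
# Note: INFILL_METHODS is a Literal built from a runtime tuple, so the accepted
# method names (and the enum exposed in the generated schema) automatically
# track PatchMatch availability. A minimal sketch of the same pattern:
#
#     from typing import Literal, get_args
#     Colors = Literal[tuple(["red", "green"])]  # type: ignore
#     get_args(Colors)  # -> ('red', 'green')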


def infill_patchmatch(im: Image.Image) -> Image.Image:
    if im.mode != "RGBA":
        return im

    # Skip patchmatch if patchmatch isn't available
    if not PatchMatch.patchmatch_available():
        return im

    # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
    im_patched_np = PatchMatch.inpaint(
        im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
    )
    im_patched = Image.fromarray(im_patched_np, mode="RGB")
    return im_patched


def get_tile_images(image: np.ndarray, width=8, height=8):
    _nrows, _ncols, depth = image.shape
    _strides = image.strides

    nrows, _m = divmod(_nrows, height)
    ncols, _n = divmod(_ncols, width)
    if _m != 0 or _n != 0:
        return None

    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        strides=(height * _strides[0], width * _strides[1], *_strides),
        writeable=False,
    )
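
# get_tile_images returns a zero-copy strided view over the array: a 512x512
# RGBA image with 32px tiles becomes shape (16, 16, 32, 32, 4). It returns None
# when the image dimensions are not exact multiples of the tile size.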


def tile_fill_missing(
    im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    a = np.asarray(im, dtype=np.uint8)

    tile_size_tuple = (tile_size, tile_size)

    # Get the image as tiles of a specified size
    tiles = get_tile_images(a, *tile_size_tuple).copy()

    # Get the mask as tiles
    tiles_mask = tiles[:, :, :, :, 3]

    # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
    tmask_shape = tiles_mask.shape
    tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
    n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
    tiles_mask = tiles_mask > 0
    tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)

    # Get RGB tiles in single array and filter by the mask
    tshape = tiles.shape
    tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
    filtered_tiles = tiles_all[tiles_mask]

    if len(filtered_tiles) == 0:
        return im

    # Find all invalid tiles and replace with a random valid tile
    replace_count = (tiles_mask == False).sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
        rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
    ]

    # Convert back to an image
    tiles_all = tiles_all.reshape(tshape)
    tiles_all = tiles_all.swapaxes(1, 2)
    st = tiles_all.reshape(
        (
            math.prod(tiles_all.shape[0:2]),
            math.prod(tiles_all.shape[2:4]),
            tiles_all.shape[4],
        )
    )
    si = Image.fromarray(st, mode="RGBA")

    return si
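
# Minimal usage sketch (assumes an RGBA image with transparent holes; the file
# name is illustrative):
#
#     im = Image.open("masked.png").convert("RGBA")
#     filled = tile_fill_missing(im, tile_size=32, seed=42)
#
# Tiles containing any fully transparent pixel are replaced with randomly
# chosen fully opaque tiles; callers then paste the original pixels back on
# top, as InfillTileInvocation does below.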


class InfillColorInvocation(BaseInvocation):
    """Infills transparent areas of an image with a solid color"""

    type: Literal["infill_rgba"] = "infill_rgba"
    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    color: Optional[ColorField] = Field(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        solid_bg = Image.new("RGBA", image.size, self.color.tuple())
        infilled = Image.alpha_composite(solid_bg, image)

        infilled.paste(image, (0, 0), image.split()[-1])

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, infilled, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=infilled,
        )


class InfillTileInvocation(BaseInvocation):
    """Infills transparent areas of an image with tiles of the image"""

    type: Literal["infill_tile"] = "infill_tile"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
    seed: int = Field(
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        infilled = tile_fill_missing(
            image.copy(), seed=self.seed, tile_size=self.tile_size
        )
        infilled.paste(image, (0, 0), image.split()[-1])

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, infilled, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=infilled,
        )


class InfillPatchMatchInvocation(BaseInvocation):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    type: Literal["infill_patchmatch"] = "infill_patchmatch"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        if PatchMatch.patchmatch_available():
            infilled = infill_patchmatch(image.copy())
        else:
            raise ValueError("PatchMatch is not available on this system")

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, infilled, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=infilled,
        )


invokeai/app/invocations/latent.py (new file, 482 lines)
@@ -0,0 +1,482 @@

# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import random
from typing import Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
import torch

from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.misc import SEED_MAX, get_random_seed

from invokeai.app.util.step_callback import stable_diffusion_step_callback

from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_file_storage import ImageType
from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline


class LatentsField(BaseModel):
    """A latents field used for passing latents between invocations"""

    latents_name: Optional[str] = Field(default=None, description="The name of the latents")

    class Config:
        schema_extra = {"required": ["latents_name"]}


class LatentsOutput(BaseInvocationOutput):
    """Base class for invocations that output latents"""

    # fmt: off
    type: Literal["latents_output"] = "latents_output"

    # Inputs
    latents: LatentsField = Field(default=None, description="The output latents")
    width: int = Field(description="The width of the latents in pixels")
    height: int = Field(description="The height of the latents in pixels")
    # fmt: on


def build_latents_output(latents_name: str, latents: torch.Tensor):
    return LatentsOutput(
        latents=LatentsField(latents_name=latents_name),
        width=latents.size()[3] * 8,
        height=latents.size()[2] * 8,
    )
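
# The 8x factor reflects the Stable Diffusion VAE's spatial downsampling: a
# latents tensor of shape (1, 4, 64, 64) corresponds to a 512x512 image, so
# width/height are reported in pixel space as size()[3] * 8 and size()[2] * 8.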


class NoiseOutput(BaseInvocationOutput):
    """Invocation noise output"""

    # fmt: off
    type: Literal["noise_output"] = "noise_output"

    # Inputs
    noise: LatentsField = Field(default=None, description="The output noise")
    width: int = Field(description="The width of the noise in pixels")
    height: int = Field(description="The height of the noise in pixels")
    # fmt: on


def build_noise_output(latents_name: str, latents: torch.Tensor):
    return NoiseOutput(
        noise=LatentsField(latents_name=latents_name),
        width=latents.size()[3] * 8,
        height=latents.size()[2] * 8,
    )


SAMPLER_NAME_VALUES = Literal[
    tuple(list(SCHEDULER_MAP.keys()))
]


def get_scheduler(scheduler_name: str, model: StableDiffusionGeneratorPipeline) -> Scheduler:
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])

    scheduler_config = model.scheduler.config
    if "_backup" in scheduler_config:
        scheduler_config = scheduler_config["_backup"]
    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
    scheduler = scheduler_class.from_config(scheduler_config)

    # hack copied over from generate.py
    if not hasattr(scheduler, 'uses_inpainting_model'):
        scheduler.uses_inpainting_model = lambda: False
    return scheduler
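
# Minimal usage sketch, assuming `pipeline` is a loaded
# StableDiffusionGeneratorPipeline:
#
#     pipeline.scheduler = get_scheduler("ddim", pipeline)
#
# The pristine scheduler config is stashed under "_backup" so a later call can
# rebuild a different scheduler from the same base settings instead of layering
# one scheduler's extra config on top of another's.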


def get_noise(
    width: int,
    height: int,
    device: torch.device,
    seed: int = 0,
    latent_channels: int = 4,
    use_mps_noise: bool = False,
    downsampling_factor: int = 8,
):
    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(latent_channels, 4)
    use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
    generator = torch.Generator(device=use_device).manual_seed(seed)
    x = torch.randn(
        [
            1,
            input_channels,
            height // downsampling_factor,
            width // downsampling_factor,
        ],
        dtype=torch_dtype(device),
        device=use_device,
        generator=generator,
    ).to(device)
    # if self.perlin > 0.0:
    #     perlin_noise = self.get_perlin_noise(
    #         width // self.downsampling_factor, height // self.downsampling_factor
    #     )
    #     x = (1 - self.perlin) * x + self.perlin * perlin_noise
    return x
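
# With the default downsampling_factor of 8, a 512x512 request yields a noise
# tensor of shape (1, 4, 64, 64). The seeded generator runs on the CPU for MPS
# devices and the result is moved to the target device afterwards, keeping
# seeds reproducible across backends.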


class NoiseInvocation(BaseInvocation):
    """Generates latent noise."""

    type: Literal["noise"] = "noise"

    # Inputs
    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "noise"],
            },
        }

    def invoke(self, context: InvocationContext) -> NoiseOutput:
        device = torch.device(choose_torch_device())
        noise = get_noise(self.width, self.height, device, self.seed)

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, noise)
        return build_noise_output(latents_name=name, latents=noise)


# Text to image
class TextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["t2l"] = "t2l"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: float = Field(default=7.5, gt=0, description="The classifier-free guidance scale; higher values keep the result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
    model: str = Field(default="", description="The model to use (currently ignored)")
    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
    ) -> None:
        stable_diffusion_step_callback(
            context=context,
            intermediate_state=intermediate_state,
            node=self.dict(),
            source_node_id=source_node_id,
        )

    def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
        model_info = choose_model(model_manager, self.model)
        model_name = model_info['model_name']
        model_hash = model_info['hash']
        model: StableDiffusionGeneratorPipeline = model_info['model']
        model.scheduler = get_scheduler(
            model=model,
            scheduler_name=self.scheduler
        )

        if isinstance(model, DiffusionPipeline):
            for component in [model.unet, model.vae]:
                configure_model_padding(
                    component,
                    self.seamless,
                    self.seamless_axes,
                )
        else:
            configure_model_padding(
                model,
                self.seamless,
                self.seamless_axes,
            )

        return model

    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

        conditioning_data = ConditioningData(
            uc,
            c,
            self.cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=0.0,  # threshold,
                warmup=0.2,  # warmup,
                h_symmetry_time_pct=None,  # h_symmetry_time_pct,
                v_symmetry_time_pct=None,  # v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)  # ddim_eta
        return conditioning_data

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, source_node_id, state)

        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        # TODO: Verify the noise is the right size

        result_latents, result_attention_map_saver = model.latents_from_embeddings(
            latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)


class LatentsToLatentsInvocation(TextToLatentsInvocation):
    """Generates latents using latents as base image."""

    type: Literal["l2l"] = "l2l"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
    strength: float = Field(default=0.5, description="The strength of the latents to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)
        latent = context.services.latents.get(self.latents.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, source_node_id, state)

        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        # TODO: Verify the noise is the right size

        initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
            latent, device=model.device, dtype=latent.dtype
        )

        timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)

        result_latents, result_attention_map_saver = model.latents_from_embeddings(
            latents=initial_latents,
            timesteps=timesteps,
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)


# Latent to image
class LatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i"] = "l2i"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    model: str = Field(default="", description="The model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO: this only really needs the vae
        model_info = choose_model(context.services.model_manager, self.model)
        model: StableDiffusionGeneratorPipeline = model_info['model']

        with torch.inference_mode():
            np_image = model.decode_latents(latents)
            image = model.numpy_to_pil(np_image)[0]

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        torch.cuda.empty_cache()

        context.services.images.save(image_type, image_name, image, metadata)
        return build_image_output(
            image_type=image_type, image_name=image_name, image=image
        )


LATENTS_INTERPOLATION_MODE = Literal[
    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]


class ResizeLatentsInvocation(BaseInvocation):
    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""

    type: Literal["lresize"] = "lresize"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to resize")
    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        resized_latents = torch.nn.functional.interpolate(
            latents,
            size=(self.height // 8, self.width // 8),
            mode=self.mode,
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)


class ScaleLatentsInvocation(BaseInvocation):
    """Scales latents by a given factor."""

    type: Literal["lscale"] = "lscale"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to scale")
    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # resizing
        resized_latents = torch.nn.functional.interpolate(
            latents,
            scale_factor=self.scale_factor,
            mode=self.mode,
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)


class ImageToLatentsInvocation(BaseInvocation):
    """Encodes an image into latents."""

    type: Literal["i2l"] = "i2l"

    # Inputs
    image: Union[ImageField, None] = Field(description="The image to encode")
    model: str = Field(default="", description="The model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
                "type_hints": {"model": "model"},
            },
        }

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        # TODO: this only really needs the vae
        model_info = choose_model(context.services.model_manager, self.model)
        model: StableDiffusionGeneratorPipeline = model_info["model"]

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        latents = model.non_noised_latents_from_image(
            image_tensor,
            device=model._model_group.device_for(model.unet),
            dtype=model.unet.dtype,
        )

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)


invokeai/app/invocations/math.py (new file, 100 lines)
@@ -0,0 +1,100 @@

# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal

from pydantic import BaseModel, Field
import numpy as np

from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)


class MathInvocationConfig(BaseModel):
    """Helper class to provide all math invocations with additional config"""

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["math"],
            }
        }


class IntOutput(BaseInvocationOutput):
    """An integer output"""

    # fmt: off
    type: Literal["int_output"] = "int_output"
    a: int = Field(default=None, description="The output integer")
    # fmt: on


class AddInvocation(BaseInvocation, MathInvocationConfig):
    """Adds two numbers"""

    # fmt: off
    type: Literal["add"] = "add"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    # fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a + self.b)


class SubtractInvocation(BaseInvocation, MathInvocationConfig):
    """Subtracts two numbers"""

    # fmt: off
    type: Literal["sub"] = "sub"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    # fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a - self.b)


class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
    """Multiplies two numbers"""

    # fmt: off
    type: Literal["mul"] = "mul"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    # fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a * self.b)


class DivideInvocation(BaseInvocation, MathInvocationConfig):
    """Divides two numbers"""

    # fmt: off
    type: Literal["div"] = "div"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    # fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=int(self.a / self.b))
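
# Note: int(a / b) truncates toward zero, which differs from floor division for
# negative operands: int(7 / -2) == -3 while 7 // -2 == -4.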


class RandomIntInvocation(BaseInvocation):
    """Outputs a single random integer."""

    # fmt: off
    type: Literal["rand_int"] = "rand_int"
    low: int = Field(default=0, description="The inclusive low value")
    high: int = Field(
        default=np.iinfo(np.int32).max, description="The exclusive high value"
    )
    # fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=np.random.randint(self.low, self.high))


invokeai/app/invocations/params.py (new file, 18 lines)
@@ -0,0 +1,18 @@

# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput

# Pass-through parameter nodes - used by subgraphs


class ParamIntInvocation(BaseInvocation):
    """An integer parameter"""

    # fmt: off
    type: Literal["param_int"] = "param_int"
    a: int = Field(default=0, description="The integer value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a)


@@ -1,9 +1,22 @@
from typing import Literal

from pydantic.fields import Field

from .baseinvocation import BaseInvocationOutput


class PromptOutput(BaseInvocationOutput):
    """Base class for invocations that output a prompt"""

    # fmt: off
    type: Literal["prompt"] = "prompt"

    prompt: str = Field(default=None, description="The output prompt")
    # fmt: on

    class Config:
        schema_extra = {
            'required': [
                'type',
                'prompt',
            ]
        }


invokeai/app/invocations/reconstruct.py (new file, 56 lines)
@@ -0,0 +1,56 @@

from typing import Literal, Union

from pydantic import Field

from invokeai.app.models.image import ImageField, ImageType

from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output


class RestoreFaceInvocation(BaseInvocation):
    """Restores faces in an image."""

    # fmt: off
    type: Literal["restore_face"] = "restore_face"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
    # fmt: on

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["restoration", "image"],
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        results = context.services.restoration.upscale_and_reconstruct(
            image_list=[[image, 0]],
            upscale=None,
            strength=self.strength,  # GFPGAN strength
            save_original=False,
            image_callback=None,
        )

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, results[0][0], metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=results[0][0]
        )


invokeai/app/invocations/upscale.py (new file, 60 lines)
@@ -0,0 +1,60 @@

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Union

from pydantic import Field

from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output


class UpscaleInvocation(BaseInvocation):
    """Upscales an image."""

    # fmt: off
    type: Literal["upscale"] = "upscale"

    # Inputs
    image: Union[ImageField, None] = Field(description="The input image", default=None)
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
    level: Literal[2, 4] = Field(default=2, description="The upscale level")
    # fmt: on

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["upscaling", "image"],
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )
        results = context.services.restoration.upscale_and_reconstruct(
            image_list=[[image, 0]],
            upscale=(self.level, self.strength),
            strength=0.0,  # GFPGAN strength
            save_original=False,
            image_callback=None,
        )

        # Results are image and seed, unwrap for now
        # TODO: can this return multiple results?
        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, results[0][0], metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=results[0][0]
        )


invokeai/app/invocations/util/choose_model.py (new file, 14 lines)
@@ -0,0 +1,14 @@

from invokeai.backend.model_management.model_manager import ModelManager


def choose_model(model_manager: ModelManager, model_name: str):
    """Returns the default model if `model_name` is not a valid model, else returns the selected model."""
    logger = model_manager.logger
    if model_name and not model_manager.valid_model(model_name):
        default_model_name = model_manager.default_model()
        logger.warning(f"'{model_name}' is not a valid model name. Using default model '{default_model_name}' instead.")
        model = model_manager.get_model()
    else:
        model = model_manager.get_model(model_name)

    return model
|
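For illustration, a minimal sketch of the fallback behaviour with a stubbed-out manager; the stub mirrors only the calls made above and is hypothetical, not the real ModelManager API:

import logging

class StubManager:
    """Hypothetical stand-in exposing just the methods choose_model calls."""
    logger = logging.getLogger("stub")
    def valid_model(self, name): return name == "sd-1.5"
    def default_model(self): return "sd-1.5"
    def get_model(self, name=None): return name or self.default_model()

print(choose_model(StubManager(), "not-a-model"))  # warns, returns "sd-1.5"
print(choose_model(StubManager(), "sd-1.5"))       # returns "sd-1.5"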
invokeai/app/models/exceptions.py (new file, 3 lines)
@@ -0,0 +1,3 @@
class CanceledException(Exception):
    """Execution canceled by user."""
    pass
invokeai/app/models/image.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field

from invokeai.app.util.enum import MetaEnum


class ImageType(str, Enum, metaclass=MetaEnum):
    """The type of an image."""

    RESULT = "results"
    UPLOAD = "uploads"
    INTERMEDIATE = "intermediates"


class ImageCategory(str, Enum, metaclass=MetaEnum):
    """The category of an image. Use ImageCategory.OTHER for non-default categories."""

    IMAGE = "image"
    CONTROL_IMAGE = "control_image"
    OTHER = "other"


def is_image_type(obj):
    try:
        ImageType(obj)
    except ValueError:
        return False
    return True


class ImageField(BaseModel):
    """An image field used for passing image objects between invocations."""

    image_type: ImageType = Field(
        default=ImageType.RESULT, description="The type of the image"
    )
    image_name: Optional[str] = Field(default=None, description="The name of the image")

    class Config:
        schema_extra = {"required": ["image_type", "image_name"]}


class ColorField(BaseModel):
    r: int = Field(ge=0, le=255, description="The red component")
    g: int = Field(ge=0, le=255, description="The green component")
    b: int = Field(ge=0, le=255, description="The blue component")
    a: int = Field(ge=0, le=255, description="The alpha component")

    def tuple(self) -> Tuple[int, int, int, int]:
        return (self.r, self.g, self.b, self.a)
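A minimal sketch of how these models behave, using only what is defined above:

assert is_image_type("uploads") and not is_image_type("bogus")

field = ImageField(image_type=ImageType.UPLOAD, image_name="cat.png")
print(field.json())  # {"image_type": "uploads", "image_name": "cat.png"}

red = ColorField(r=255, g=0, b=0, a=255)
print(red.tuple())   # (255, 0, 0, 255)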
invokeai/app/models/metadata.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr


class ImageMetadata(BaseModel):
    """
    Core generation metadata for an image/tensor generated in InvokeAI.

    Also includes any metadata from the image's PNG tEXt chunks.

    Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.

    Full metadata may be accessed by querying for the session in the `graph_executions` table.
    """

    positive_conditioning: Optional[StrictStr] = Field(
        default=None, description="The positive conditioning."
    )
    negative_conditioning: Optional[StrictStr] = Field(
        default=None, description="The negative conditioning."
    )
    width: Optional[StrictInt] = Field(
        default=None, description="Width of the image/tensor in pixels."
    )
    height: Optional[StrictInt] = Field(
        default=None, description="Height of the image/tensor in pixels."
    )
    seed: Optional[StrictInt] = Field(
        default=None, description="The seed used for noise generation."
    )
    cfg_scale: Optional[StrictFloat] = Field(
        default=None, description="The classifier-free guidance scale."
    )
    steps: Optional[StrictInt] = Field(
        default=None, description="The number of steps used for inference."
    )
    scheduler: Optional[StrictStr] = Field(
        default=None, description="The scheduler used for inference."
    )
    model: Optional[StrictStr] = Field(
        default=None, description="The model used for inference."
    )
    strength: Optional[StrictFloat] = Field(
        default=None,
        description="The strength used for image-to-image/tensor-to-tensor.",
    )
    image: Optional[StrictStr] = Field(
        default=None, description="The ID of the initial image."
    )
    tensor: Optional[StrictStr] = Field(
        default=None, description="The ID of the initial tensor."
    )
    # Pending model refactor:
    # vae: Optional[str] = Field(default=None, description="The VAE used for decoding.")
    # unet: Optional[str] = Field(default=None, description="The UNet used for inference.")
    # clip: Optional[str] = Field(default=None, description="The CLIP Encoder used for conditioning.")
    extra: Optional[StrictStr] = Field(
        default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
    )
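A short sketch of the strict typing in action; Strict* fields reject coerced values rather than converting them:

from pydantic import ValidationError

meta = ImageMetadata(seed=123, steps=30, cfg_scale=7.5)
print(meta.json(exclude_none=True))  # {"seed": 123, "cfg_scale": 7.5, "steps": 30}

try:
    ImageMetadata(seed="123")  # a string is not coerced to StrictInt
except ValidationError as e:
    print("rejected:", e.errors()[0]["type"])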
invokeai/app/models/resources.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# TODO: Make a new model for this
from enum import Enum

from invokeai.app.util.enum import MetaEnum


class ResourceType(str, Enum, metaclass=MetaEnum):
    """The type of a resource."""

    IMAGES = "images"
    TENSORS = "tensors"


# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
#     """The origin of a resource (eg image or tensor)."""

#     RESULTS = "results"
#     UPLOADS = "uploads"
#     INTERMEDIATES = "intermediates"


class TensorKind(str, Enum, metaclass=MetaEnum):
    """The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""

    IMAGE_LATENTS = "image_latents"
    CONDITIONING = "conditioning"
    OTHER = "other"
invokeai/app/services/db.ipynb (new file, 578 lines)
@@ -0,0 +1,578 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "from enum import Enum\n",
    "import enum\n",
    "import sqlite3\n",
    "import threading\n",
    "from typing import Optional, Type, TypeVar, Union\n",
    "from PIL.Image import Image as PILImage\n",
    "from pydantic import BaseModel, Field\n",
    "from torch import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class ResourceOrigin(str, Enum):\n",
    "    \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
    "\n",
    "    RESULTS = \"results\"\n",
    "    UPLOADS = \"uploads\"\n",
    "    INTERMEDIATES = \"intermediates\"\n",
    "\n",
    "\n",
    "class ImageKind(str, Enum):\n",
    "    \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
    "\n",
    "    IMAGE = \"image\"\n",
    "    CONTROL_IMAGE = \"control_image\"\n",
    "    OTHER = \"other\"\n",
    "\n",
    "\n",
    "class TensorKind(str, Enum):\n",
    "    \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
    "\n",
    "    IMAGE_LATENTS = \"image_latents\"\n",
    "    CONDITIONING = \"conditioning\"\n",
    "    OTHER = \"other\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
    "    \"\"\"\n",
    "    Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
    "    \"\"\"\n",
    "\n",
    "    delimiter = \", \"\n",
    "    values = [f\"('{e.value}')\" for e in enum]\n",
    "    return delimiter.join(values)\n",
    "\n",
    "\n",
    "def create_sql_table_from_enum(\n",
    "    enum: Type[Enum],\n",
    "    table_name: str,\n",
    "    primary_key_name: str,\n",
    "    conn: sqlite3.Connection,\n",
    "    cursor: sqlite3.Cursor,\n",
    "    lock: threading.Lock,\n",
    "):\n",
    "    \"\"\"\n",
    "    Creates and populates a table to be used as a functional enum.\n",
    "    \"\"\"\n",
    "\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        values_string = create_sql_values_string_from_string_enum(enum)\n",
    "\n",
    "        cursor.execute(\n",
    "            f\"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS {table_name} (\n",
    "                {primary_key_name} TEXT PRIMARY KEY\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            f\"\"\"--sql\n",
    "            INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`resource_origins` functions as an enum for the ResourceOrigin model.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "#     create_sql_table_from_enum(\n",
    "#         enum=ResourceOrigin,\n",
    "#         table_name=\"resource_origins\",\n",
    "#         primary_key_name=\"origin_name\",\n",
    "#         conn=conn,\n",
    "#         cursor=cursor,\n",
    "#         lock=lock,\n",
    "#     )\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`image_kinds` functions as an enum for the ImageType model.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "#     create_sql_table_from_enum(\n",
    "#         enum=ImageKind,\n",
    "#         table_name=\"image_kinds\",\n",
    "#         primary_key_name=\"kind_name\",\n",
    "#         conn=conn,\n",
    "#         cursor=cursor,\n",
    "#         lock=lock,\n",
    "#     )\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`tensor_kinds` functions as an enum for the TensorType model.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "#     create_sql_table_from_enum(\n",
    "#         enum=TensorKind,\n",
    "#         table_name=\"tensor_kinds\",\n",
    "#         primary_key_name=\"kind_name\",\n",
    "#         conn=conn,\n",
    "#         cursor=cursor,\n",
    "#         lock=lock,\n",
    "#     )\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`images` stores all images, regardless of type\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS images (\n",
    "                id TEXT PRIMARY KEY,\n",
    "                origin TEXT,\n",
    "                image_kind TEXT,\n",
    "                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
    "                FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
    "                FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`images_results` stores additional data specific to `results` images.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS images_results (\n",
    "                images_id TEXT PRIMARY KEY,\n",
    "                session_id TEXT NOT NULL,\n",
    "                node_id TEXT NOT NULL,\n",
    "                FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`images_intermediates` stores additional data specific to `intermediates` images\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS images_intermediates (\n",
    "                images_id TEXT PRIMARY KEY,\n",
    "                session_id TEXT NOT NULL,\n",
    "                node_id TEXT NOT NULL,\n",
    "                FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "\"\"\"\n",
    "`images_metadata` stores basic metadata for any image type\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS images_metadata (\n",
    "                images_id TEXT PRIMARY KEY,\n",
    "                metadata TEXT,\n",
    "                FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "# `tensors` table: stores references to tensor\n",
    "\n",
    "\n",
    "def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS tensors (\n",
    "                id TEXT PRIMARY KEY,\n",
    "                origin TEXT,\n",
    "                tensor_kind TEXT,\n",
    "                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
    "                FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
    "                FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "# `tensors_results` stores additional data specific to `result` tensor\n",
    "\n",
    "\n",
    "def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS tensors_results (\n",
    "                tensors_id TEXT PRIMARY KEY,\n",
    "                session_id TEXT NOT NULL,\n",
    "                node_id TEXT NOT NULL,\n",
    "                FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
    "\n",
    "\n",
    "def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
    "                tensors_id TEXT PRIMARY KEY,\n",
    "                session_id TEXT NOT NULL,\n",
    "                node_id TEXT NOT NULL,\n",
    "                FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n",
    "\n",
    "\n",
    "# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
    "\n",
    "\n",
    "def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
    "    try:\n",
    "        lock.acquire()\n",
    "\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
    "                tensors_id TEXT PRIMARY KEY,\n",
    "                metadata TEXT,\n",
    "                FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
    "            );\n",
    "            \"\"\"\n",
    "        )\n",
    "        cursor.execute(\n",
    "            \"\"\"--sql\n",
    "            CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
    "            \"\"\"\n",
    "        )\n",
    "        conn.commit()\n",
    "    finally:\n",
    "        lock.release()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
    "if (os.path.exists(db_path)):\n",
    "    os.remove(db_path)\n",
    "\n",
    "conn = sqlite3.connect(\n",
    "    db_path, check_same_thread=False\n",
    ")\n",
    "cursor = conn.cursor()\n",
    "lock = threading.Lock()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_sql_table_from_enum(\n",
    "    enum=ResourceOrigin,\n",
    "    table_name=\"resource_origins\",\n",
    "    primary_key_name=\"origin_name\",\n",
    "    conn=conn,\n",
    "    cursor=cursor,\n",
    "    lock=lock,\n",
    ")\n",
    "\n",
    "create_sql_table_from_enum(\n",
    "    enum=ImageKind,\n",
    "    table_name=\"image_kinds\",\n",
    "    primary_key_name=\"kind_name\",\n",
    "    conn=conn,\n",
    "    cursor=cursor,\n",
    "    lock=lock,\n",
    ")\n",
    "\n",
    "create_sql_table_from_enum(\n",
    "    enum=TensorKind,\n",
    "    table_name=\"tensor_kinds\",\n",
    "    primary_key_name=\"kind_name\",\n",
    "    conn=conn,\n",
    "    cursor=cursor,\n",
    "    lock=lock,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_images_table(conn, cursor, lock)\n",
    "create_images_results_table(conn, cursor, lock)\n",
    "create_images_intermediates_table(conn, cursor, lock)\n",
    "create_images_metadata_table(conn, cursor, lock)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_tensors_table(conn, cursor, lock)\n",
    "create_tensors_results_table(conn, cursor, lock)\n",
    "create_tensors_intermediates_table(conn, cursor, lock)\n",
    "create_tensors_metadata_table(conn, cursor, lock)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from pydantic import StrictStr\n",
    "\n",
    "\n",
    "class GeneratedImageOrLatentsMetadata(BaseModel):\n",
    "    \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
    "\n",
    "    Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
    "\n",
    "    Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
    "    \"\"\"\n",
    "\n",
    "    positive_conditioning: Optional[StrictStr] = Field(\n",
    "        default=None, description=\"The positive conditioning.\"\n",
    "    )\n",
    "    negative_conditioning: Optional[str] = Field(\n",
    "        default=None, description=\"The negative conditioning.\"\n",
    "    )\n",
    "    width: Optional[int] = Field(\n",
    "        default=None, description=\"Width of the image/tensor in pixels.\"\n",
    "    )\n",
    "    height: Optional[int] = Field(\n",
    "        default=None, description=\"Height of the image/tensor in pixels.\"\n",
    "    )\n",
    "    seed: Optional[int] = Field(\n",
    "        default=None, description=\"The seed used for noise generation.\"\n",
    "    )\n",
    "    cfg_scale: Optional[float] = Field(\n",
    "        default=None, description=\"The classifier-free guidance scale.\"\n",
    "    )\n",
    "    steps: Optional[int] = Field(\n",
    "        default=None, description=\"The number of steps used for inference.\"\n",
    "    )\n",
    "    scheduler: Optional[str] = Field(\n",
    "        default=None, description=\"The scheduler used for inference.\"\n",
    "    )\n",
    "    model: Optional[str] = Field(\n",
    "        default=None, description=\"The model used for inference.\"\n",
    "    )\n",
    "    strength: Optional[float] = Field(\n",
    "        default=None,\n",
    "        description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
    "    )\n",
    "    image: Optional[str] = Field(\n",
    "        default=None, description=\"The ID of the initial image.\"\n",
    "    )\n",
    "    tensor: Optional[str] = Field(\n",
    "        default=None, description=\"The ID of the initial tensor.\"\n",
    "    )\n",
    "    # Pending model refactor:\n",
    "    # vae: Optional[str] = Field(default=None, description=\"The VAE used for decoding.\")\n",
    "    # unet: Optional[str] = Field(default=None, description=\"The UNet used for inference.\")\n",
    "    # clip: Optional[str] = Field(default=None, description=\"The CLIP Encoder used for conditioning.\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
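The notebook's enum-table helpers can be exercised standalone against an in-memory database; a minimal sketch, reusing only definitions from the notebook itself:

import sqlite3
import threading
from enum import Enum

class Demo(str, Enum):
    A = "a"
    B = "b"

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
create_sql_table_from_enum(Demo, "demos", "name", conn, cursor, threading.Lock())
print(cursor.execute("SELECT name FROM demos").fetchall())  # [('a',), ('b',)]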
invokeai/app/services/default_graphs.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC


default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'


def create_text_to_image() -> LibraryGraph:
    return LibraryGraph(
        id=default_text_to_image_graph_id,
        name='t2i',
        description='Converts text to an image',
        graph=Graph(
            nodes={
                'width': ParamIntInvocation(id='width', a=512),
                'height': ParamIntInvocation(id='height', a=512),
                'seed': ParamIntInvocation(id='seed', a=-1),
                '3': NoiseInvocation(id='3'),
                '4': CompelInvocation(id='4'),
                '5': CompelInvocation(id='5'),
                '6': TextToLatentsInvocation(id='6'),
                '7': LatentsToImageInvocation(id='7'),
            },
            edges=[
                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
                Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
                Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
                Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
                Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
            ]
        ),
        exposed_inputs=[
            ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
            ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
            ExposedNodeInput(node_path='width', field='a', alias='width'),
            ExposedNodeInput(node_path='height', field='a', alias='height'),
            ExposedNodeInput(node_path='seed', field='a', alias='seed'),
        ],
        exposed_outputs=[
            ExposedNodeOutput(node_path='7', field='image', alias='image')
        ])


def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
    """Creates the default system graphs, or adds new versions if the old ones don't match"""

    # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
    graphs: list[LibraryGraph] = list()

    # text_to_image = graph_library.get(default_text_to_image_graph_id)

    # # TODO: Check if the graph is the same as the default one, and if not, update it
    # #if text_to_image is None:
    text_to_image = create_text_to_image()
    graph_library.set(text_to_image)

    graphs.append(text_to_image)

    return graphs
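A small sketch of how the library graph's public surface can be inspected, using only the fields constructed above:

t2i = create_text_to_image()
for inp in t2i.exposed_inputs:
    print(f"{inp.alias} -> node {inp.node_path}.{inp.field}")
# positive_prompt -> node 4.prompt
# negative_prompt -> node 5.prompt
# width -> node width.a  (and so on for height, seed)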
invokeai/app/services/events.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any, Optional
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp


class EventServiceBase:
    """Basic event bus, to have an empty stand-in when not needed"""

    session_event: str = "session_event"

    def dispatch(self, event_name: str, payload: Any) -> None:
        pass

    def __emit_session_event(self, event_name: str, payload: dict) -> None:
        payload["timestamp"] = get_timestamp()
        self.dispatch(
            event_name=EventServiceBase.session_event,
            payload=dict(event=event_name, data=payload),
        )

    # Define events here for every event in the system.
    # This will make them easier to integrate until we find a schema generator.
    def emit_generator_progress(
        self,
        graph_execution_state_id: str,
        node: dict,
        source_node_id: str,
        progress_image: ProgressImage | None,
        step: int,
        total_steps: int,
    ) -> None:
        """Emitted when there is generation progress"""
        self.__emit_session_event(
            event_name="generator_progress",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
                progress_image=progress_image.dict() if progress_image is not None else None,
                step=step,
                total_steps=total_steps,
            ),
        )

    def emit_invocation_complete(
        self,
        graph_execution_state_id: str,
        result: dict,
        node: dict,
        source_node_id: str,
        image_url: Optional[str] = None,
        thumbnail_url: Optional[str] = None,
    ) -> None:
        """Emitted when an invocation has completed"""
        self.__emit_session_event(
            event_name="invocation_complete",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
                result=result,
                image_url=image_url,
                thumbnail_url=thumbnail_url,
            ),
        )

    def emit_invocation_error(
        self,
        graph_execution_state_id: str,
        node: dict,
        source_node_id: str,
        error: str,
    ) -> None:
        """Emitted when an invocation has encountered an error"""
        self.__emit_session_event(
            event_name="invocation_error",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
                error=error,
            ),
        )

    def emit_invocation_started(
        self, graph_execution_state_id: str, node: dict, source_node_id: str
    ) -> None:
        """Emitted when an invocation has started"""
        self.__emit_session_event(
            event_name="invocation_started",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
            ),
        )

    def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
        """Emitted when a session has completed all invocations"""
        self.__emit_session_event(
            event_name="graph_execution_state_complete",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
            ),
        )
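Because every emit_* helper funnels through dispatch, a console-logging bus is a one-method subclass; a minimal sketch:

class PrintEventService(EventServiceBase):
    def dispatch(self, event_name: str, payload: Any) -> None:
        # payload is {"event": <name>, "data": {..., "timestamp": ...}}
        print(event_name, payload["event"], payload["data"].get("graph_execution_state_id"))

bus = PrintEventService()
bus.emit_graph_execution_complete("session-123")
# session_event graph_execution_state_complete session-123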
invokeai/app/services/graph.py (new file, 1194 lines)
File diff suppressed because it is too large
invokeai/app/services/image_file_storage.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional

from PIL.Image import Image as PILImageType
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from send2trash import send2trash

from invokeai.app.models.image import ImageType
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail


class ImageFileStorageBase(ABC):
    """Low-level service responsible for storing and retrieving image files."""

    class ImageFileNotFoundException(Exception):
        """Raised when an image file is not found in storage."""

        def __init__(self, message="Image file not found"):
            super().__init__(message)

    class ImageFileSaveException(Exception):
        """Raised when an image cannot be saved."""

        def __init__(self, message="Image file not saved"):
            super().__init__(message)

    class ImageFileDeleteException(Exception):
        """Raised when an image cannot be deleted."""

        def __init__(self, message="Image file not deleted"):
            super().__init__(message)

    @abstractmethod
    def get(self, image_type: ImageType, image_name: str) -> PILImageType:
        """Retrieves an image as PIL Image."""
        pass

    # TODO: make this a bit more flexible for e.g. cloud storage
    @abstractmethod
    def get_path(
        self, image_type: ImageType, image_name: str, thumbnail: bool = False
    ) -> str:
        """Gets the internal path to an image or thumbnail."""
        pass

    @abstractmethod
    def save(
        self,
        image: PILImageType,
        image_type: ImageType,
        image_name: str,
        pnginfo: Optional[PngInfo] = None,
        thumbnail_size: int = 256,
    ) -> None:
        """Saves an image as a PNG along with a WEBP thumbnail (256px by default)."""
        pass

    @abstractmethod
    def delete(self, image_type: ImageType, image_name: str) -> None:
        """Deletes an image and its thumbnail (if one exists)."""
        pass


class DiskImageFileStorage(ImageFileStorageBase):
    """Stores images on disk"""

    __output_folder: str
    __cache_ids: Queue  # TODO: this is an incredibly naive cache
    __cache: Dict[str, PILImageType]
    __max_cache_size: int

    def __init__(self, output_folder: str):
        self.__output_folder = output_folder
        self.__cache = dict()
        self.__cache_ids = Queue()
        self.__max_cache_size = 10  # TODO: get this from config

        Path(output_folder).mkdir(parents=True, exist_ok=True)

        # TODO: don't hard-code. get/save/delete should maybe take subpath?
        for image_type in ImageType:
            Path(os.path.join(output_folder, image_type)).mkdir(
                parents=True, exist_ok=True
            )
            Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
                parents=True, exist_ok=True
            )

    def get(self, image_type: ImageType, image_name: str) -> PILImageType:
        try:
            image_path = self.get_path(image_type, image_name)
            cache_item = self.__get_cache(image_path)
            if cache_item:
                return cache_item

            image = Image.open(image_path)
            self.__set_cache(image_path, image)
            return image
        except FileNotFoundError as e:
            raise ImageFileStorageBase.ImageFileNotFoundException from e

    def save(
        self,
        image: PILImageType,
        image_type: ImageType,
        image_name: str,
        pnginfo: Optional[PngInfo] = None,
        thumbnail_size: int = 256,
    ) -> None:
        try:
            image_path = self.get_path(image_type, image_name)
            image.save(image_path, "PNG", pnginfo=pnginfo)

            thumbnail_name = get_thumbnail_name(image_name)
            thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
            thumbnail_image = make_thumbnail(image, thumbnail_size)
            thumbnail_image.save(thumbnail_path)

            self.__set_cache(image_path, image)
            self.__set_cache(thumbnail_path, thumbnail_image)
        except Exception as e:
            raise ImageFileStorageBase.ImageFileSaveException from e

    def delete(self, image_type: ImageType, image_name: str) -> None:
        try:
            basename = os.path.basename(image_name)
            image_path = self.get_path(image_type, basename)

            if os.path.exists(image_path):
                send2trash(image_path)
            if image_path in self.__cache:
                del self.__cache[image_path]

            thumbnail_name = get_thumbnail_name(image_name)
            thumbnail_path = self.get_path(image_type, thumbnail_name, True)

            if os.path.exists(thumbnail_path):
                send2trash(thumbnail_path)
            if thumbnail_path in self.__cache:
                del self.__cache[thumbnail_path]
        except Exception as e:
            raise ImageFileStorageBase.ImageFileDeleteException from e

    # TODO: make this a bit more flexible for e.g. cloud storage
    def get_path(
        self, image_type: ImageType, image_name: str, thumbnail: bool = False
    ) -> str:
        # strip out any relative path shenanigans
        basename = os.path.basename(image_name)

        if thumbnail:
            thumbnail_name = get_thumbnail_name(basename)
            path = os.path.join(
                self.__output_folder, image_type, "thumbnails", thumbnail_name
            )
        else:
            path = os.path.join(self.__output_folder, image_type, basename)

        abspath = os.path.abspath(path)

        return abspath

    def __get_cache(self, image_name: str) -> PILImageType | None:
        return None if image_name not in self.__cache else self.__cache[image_name]

    def __set_cache(self, image_name: str, image: PILImageType):
        if image_name not in self.__cache:
            self.__cache[image_name] = image
            self.__cache_ids.put(
                image_name
            )  # TODO: this should refresh position for LRU cache
            if len(self.__cache) > self.__max_cache_size:
                cache_id = self.__cache_ids.get()
                if cache_id in self.__cache:
                    del self.__cache[cache_id]
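A round-trip sketch against a temporary folder (the file name and folder here are illustrative):

import tempfile
from PIL import Image
from invokeai.app.models.image import ImageType

storage = DiskImageFileStorage(tempfile.mkdtemp())
img = Image.new("RGB", (64, 64), "purple")
storage.save(image=img, image_type=ImageType.RESULT, image_name="demo.png")
print(storage.get_path(ImageType.RESULT, "demo.png"))  # .../results/demo.png
print(storage.get(ImageType.RESULT, "demo.png").size)  # (64, 64)
storage.delete(ImageType.RESULT, "demo.png")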
invokeai/app/services/image_record_storage.py (new file, 318 lines)
@@ -0,0 +1,318 @@
from abc import ABC, abstractmethod
import datetime
import sqlite3
import threading
from typing import Optional

from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
    ImageCategory,
    ImageType,
)
from invokeai.app.services.util.create_enum_table import create_enum_table
from invokeai.app.services.models.image_record import (
    ImageRecord,
    deserialize_image_record,
)

from invokeai.app.services.item_storage import PaginatedResults


class ImageRecordStorageBase(ABC):
    """Low-level service responsible for interfacing with the image record store."""

    class ImageRecordNotFoundException(Exception):
        """Raised when an image record is not found."""

        def __init__(self, message="Image record not found"):
            super().__init__(message)

    class ImageRecordSaveException(Exception):
        """Raised when an image record cannot be saved."""

        def __init__(self, message="Image record not saved"):
            super().__init__(message)

    class ImageRecordDeleteException(Exception):
        """Raised when an image record cannot be deleted."""

        def __init__(self, message="Image record not deleted"):
            super().__init__(message)

    @abstractmethod
    def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
        """Gets an image record."""
        pass

    @abstractmethod
    def get_many(
        self,
        image_type: ImageType,
        image_category: ImageCategory,
        page: int = 0,
        per_page: int = 10,
    ) -> PaginatedResults[ImageRecord]:
        """Gets a page of image records."""
        pass

    @abstractmethod
    def delete(self, image_type: ImageType, image_name: str) -> None:
        """Deletes an image record."""
        pass

    @abstractmethod
    def save(
        self,
        image_name: str,
        image_type: ImageType,
        image_category: ImageCategory,
        session_id: Optional[str],
        node_id: Optional[str],
        metadata: Optional[ImageMetadata],
        created_at: str = datetime.datetime.utcnow().isoformat(),  # NOTE: default is evaluated once, at import time
    ) -> None:
        """Saves an image record."""
        pass


class SqliteImageRecordStorage(ImageRecordStorageBase):
    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.Lock

    def __init__(self, filename: str) -> None:
        super().__init__()

        self._filename = filename
        self._conn = sqlite3.connect(filename, check_same_thread=False)
        # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
        self._conn.row_factory = sqlite3.Row
        self._cursor = self._conn.cursor()
        self._lock = threading.Lock()

        try:
            self._lock.acquire()
            # Enable foreign keys
            self._conn.execute("PRAGMA foreign_keys = ON;")
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        """Creates the tables for the `images` database."""

        # Create the `images` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS images (
                id TEXT PRIMARY KEY,
                image_type TEXT, -- non-nullable via foreign key constraint
                image_category TEXT, -- non-nullable via foreign key constraint
                session_id TEXT, -- nullable
                node_id TEXT, -- nullable
                metadata TEXT, -- nullable
                created_at TEXT NOT NULL,
                FOREIGN KEY(image_type) REFERENCES image_types(type_name),
                FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
            );
            """
        )

        # Create the `images` table indices.
        self._cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
            """
        )
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
            """
        )
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
            """
        )

        # Create the tables for image-related enums
        create_enum_table(
            enum=ImageType,
            table_name="image_types",
            primary_key_name="type_name",
            cursor=self._cursor,
        )

        create_enum_table(
            enum=ImageCategory,
            table_name="image_categories",
            primary_key_name="category_name",
            cursor=self._cursor,
        )

        # Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS tags (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                tag_name TEXT UNIQUE NOT NULL
            );
            """
        )

        # Create the `images_tags` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS images_tags (
                image_id TEXT,
                tag_id INTEGER,
                PRIMARY KEY (image_id, tag_id),
                FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
                FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
            );
            """
        )

        # Create the `images_favorites` table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS images_favorites (
                image_id TEXT PRIMARY KEY,
                favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
            );
            """
        )

    def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
        try:
            self._lock.acquire()

            self._cursor.execute(
                """--sql
                SELECT * FROM images
                WHERE id = ?;
                """,
                (image_name,),
            )

            result = self._cursor.fetchone()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise self.ImageRecordNotFoundException from e
        finally:
            self._lock.release()

        if not result:
            raise self.ImageRecordNotFoundException

        return deserialize_image_record(result)

    def get_many(
        self,
        image_type: ImageType,
        image_category: ImageCategory,
        page: int = 0,
        per_page: int = 10,
    ) -> PaginatedResults[ImageRecord]:
        try:
            self._lock.acquire()

            self._cursor.execute(
                """--sql
                SELECT * FROM images
                WHERE image_type = ? AND image_category = ?
                LIMIT ? OFFSET ?;
                """,
                (image_type.value, image_category.value, per_page, page * per_page),
            )

            result = self._cursor.fetchall()

            images = list(map(lambda r: deserialize_image_record(r), result))

            self._cursor.execute(
                """--sql
                SELECT count(*) FROM images
                WHERE image_type = ? AND image_category = ?
                """,
                (image_type.value, image_category.value),
            )

            count = self._cursor.fetchone()[0]
        except sqlite3.Error as e:
            self._conn.rollback()
            raise e
        finally:
            self._lock.release()

        pageCount = (count + per_page - 1) // per_page  # ceiling division

        return PaginatedResults(
            items=images, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def delete(self, image_type: ImageType, image_name: str) -> None:
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                DELETE FROM images
                WHERE id = ?;
                """,
                (image_name,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise ImageRecordStorageBase.ImageRecordDeleteException from e
        finally:
            self._lock.release()

    def save(
        self,
        image_name: str,
        image_type: ImageType,
        image_category: ImageCategory,
        session_id: Optional[str],
        node_id: Optional[str],
        metadata: Optional[ImageMetadata],
        created_at: str,
    ) -> None:
        try:
            metadata_json = (
                None if metadata is None else metadata.json(exclude_none=True)
            )
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT OR IGNORE INTO images (
                    id,
                    image_type,
                    image_category,
                    node_id,
                    session_id,
                    metadata,
                    created_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?);
                """,
                (
                    image_name,
                    image_type.value,
                    image_category.value,
                    node_id,
                    session_id,
                    metadata_json,
                    created_at,
                ),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise ImageRecordStorageBase.ImageRecordSaveException from e
        finally:
            self._lock.release()
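An in-memory round trip; this sketch assumes `create_enum_table` and `deserialize_image_record` behave as their call sites above suggest:

from datetime import datetime
from invokeai.app.models.image import ImageCategory, ImageType

records = SqliteImageRecordStorage(":memory:")
records.save(
    image_name="demo.png",
    image_type=ImageType.RESULT,
    image_category=ImageCategory.IMAGE,
    session_id="sess-1",
    node_id="node-1",
    metadata=None,
    created_at=datetime.utcnow().isoformat(),
)
page = records.get_many(ImageType.RESULT, ImageCategory.IMAGE)
print(page.total, page.pages)  # 1 1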
invokeai/app/services/images.py (new file, 355 lines)
@@ -0,0 +1,355 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
from logging import Logger
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import PngImagePlugin
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
ImageDTO,
|
||||
image_record_to_dto,
|
||||
)
|
||||
from invokeai.app.services.image_file_storage import ImageFileStorageBase
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
"""
|
||||
High-level service for image management.
|
||||
|
||||
Provides methods for creating, retrieving, and deleting images.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
"""Gets an image's path"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
"""Adds a tag to an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
|
||||
"""Removes a tag from an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def favorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
"""Favorites an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
|
||||
"""Unfavorites an image."""
|
||||
pass
|
||||
|
||||
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
) -> ImageDTO:
|
||||
image_name = self._create_image_name(
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
timestamp = get_iso_timestamp()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
else:
|
||||
pnginfo = None
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self._services.files.save(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
pnginfo=pnginfo,
|
||||
)
|
||||
|
||||
self._services.records.save(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
created_at=timestamp,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_type, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_category=image_category,
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
created_at=timestamp,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
except ImageRecordStorageBase.ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
except ImageFileStorageBase.ImageFileSaveException:
|
||||
self._services.logger.error("Failed to save image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem saving image record and file")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_type, image_name)
|
||||
except ImageFileStorageBase.ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_type, image_name)
|
||||
except ImageRecordStorageBase.ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e

    def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
        try:
            image_record = self._services.records.get(image_type, image_name)

            image_dto = image_record_to_dto(
                image_record,
                self._services.urls.get_image_url(image_type, image_name),
                self._services.urls.get_image_url(image_type, image_name, True),
            )

            return image_dto
        except ImageRecordStorageBase.ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise
        except Exception as e:
            self._services.logger.error("Problem getting image DTO")
            raise e

    def get_many(
        self,
        image_type: ImageType,
        image_category: ImageCategory,
        page: int = 0,
        per_page: int = 10,
    ) -> PaginatedResults[ImageDTO]:
        try:
            results = self._services.records.get_many(
                image_type,
                image_category,
                page,
                per_page,
            )

            image_dtos = list(
                map(
                    lambda r: image_record_to_dto(
                        r,
                        self._services.urls.get_image_url(image_type, r.image_name),
                        self._services.urls.get_image_url(
                            image_type, r.image_name, True
                        ),
                    ),
                    results.items,
                )
            )

            return PaginatedResults[ImageDTO](
                items=image_dtos,
                page=results.page,
                pages=results.pages,
                per_page=results.per_page,
                total=results.total,
            )
        except Exception as e:
            self._services.logger.error("Problem getting paginated image DTOs")
            raise e

    def delete(self, image_type: ImageType, image_name: str):
        # TODO: Consider using a transaction here to ensure consistency between storage and database
        try:
            self._services.files.delete(image_type, image_name)
            self._services.records.delete(image_type, image_name)
        except ImageRecordStorageBase.ImageRecordDeleteException:
            self._services.logger.error("Failed to delete image record")
            raise
        except ImageFileStorageBase.ImageFileDeleteException:
            self._services.logger.error("Failed to delete image file")
            raise
        except Exception as e:
            self._services.logger.error("Problem deleting image record and file")
            raise e

    def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
        raise NotImplementedError("The 'add_tag' method is not implemented yet.")

    def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
        raise NotImplementedError("The 'remove_tag' method is not implemented yet.")

    def favorite(self, image_type: ImageType, image_id: str) -> None:
        raise NotImplementedError("The 'favorite' method is not implemented yet.")

    def unfavorite(self, image_type: ImageType, image_id: str) -> None:
        raise NotImplementedError("The 'unfavorite' method is not implemented yet.")

    def _create_image_name(
        self,
        image_type: ImageType,
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> str:
        """Create a unique image name."""
        uuid_str = str(uuid.uuid4())

        if node_id is not None and session_id is not None:
            return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"

        return f"{image_type.value}_{image_category.value}_{uuid_str}.png"

invokeai/app/services/invocation_queue.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import time
from abc import ABC, abstractmethod
from queue import Queue

from pydantic import BaseModel, Field


class InvocationQueueItem(BaseModel):
    graph_execution_state_id: str = Field(description="The ID of the graph execution state")
    invocation_id: str = Field(description="The ID of the node being invoked")
    invoke_all: bool = Field(default=False)
    timestamp: float = Field(default_factory=time.time)


class InvocationQueueABC(ABC):
    """Abstract base class for all invocation queues"""

    @abstractmethod
    def get(self) -> InvocationQueueItem:
        pass

    @abstractmethod
    def put(self, item: InvocationQueueItem | None) -> None:
        pass

    @abstractmethod
    def cancel(self, graph_execution_state_id: str) -> None:
        pass

    @abstractmethod
    def is_canceled(self, graph_execution_state_id: str) -> bool:
        pass


class MemoryInvocationQueue(InvocationQueueABC):
    __queue: Queue
    __cancellations: dict[str, float]

    def __init__(self):
        self.__queue = Queue()
        self.__cancellations = dict()

    def get(self) -> InvocationQueueItem:
        item = self.__queue.get()

        while (
            isinstance(item, InvocationQueueItem)
            and item.graph_execution_state_id in self.__cancellations
            and self.__cancellations[item.graph_execution_state_id] > item.timestamp
        ):
            item = self.__queue.get()

        # Clear old items
        for graph_execution_state_id in list(self.__cancellations.keys()):
            if self.__cancellations[graph_execution_state_id] < item.timestamp:
                del self.__cancellations[graph_execution_state_id]

        return item

    def put(self, item: InvocationQueueItem | None) -> None:
        self.__queue.put(item)

    def cancel(self, graph_execution_state_id: str) -> None:
        if graph_execution_state_id not in self.__cancellations:
            self.__cancellations[graph_execution_state_id] = time.time()

    def is_canceled(self, graph_execution_state_id: str) -> bool:
        return graph_execution_state_id in self.__cancellations
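
# Illustrative usage sketch (editorial addition, not part of this diff), showing the
# timestamp-based cancellation above: items queued before cancel() was called are
# skipped by get(), while later items pass through.
#
#   queue = MemoryInvocationQueue()
#   queue.put(InvocationQueueItem(graph_execution_state_id="session-1", invocation_id="node-a"))
#   queue.cancel("session-1")  # records the cancellation time
#   queue.put(InvocationQueueItem(graph_execution_state_id="session-2", invocation_id="node-b"))
#   item = queue.get()  # discards the stale "session-1" item and returns the "session-2" item
#   assert item.invocation_id == "node-b"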

invokeai/app/services/invocation_services.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import TYPE_CHECKING
from logging import Logger

from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_file_storage import ImageFileStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC


if TYPE_CHECKING:
    from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
    from invokeai.app.services.invoker import InvocationProcessorABC


class InvocationServices:
    """Services that can be used by invocations"""

    events: EventServiceBase
    latents: LatentsStorageBase
    images: ImageFileStorageBase
    queue: InvocationQueueABC
    model_manager: ModelManager
    restoration: RestorationServices
    images_new: ImageService

    # NOTE: we must forward-declare any types that include invocations, since invocations can use services
    graph_library: ItemStorageABC["LibraryGraph"]
    graph_execution_manager: ItemStorageABC["GraphExecutionState"]
    processor: "InvocationProcessorABC"

    def __init__(
        self,
        model_manager: ModelManager,
        events: EventServiceBase,
        logger: Logger,
        latents: LatentsStorageBase,
        images: ImageFileStorageBase,
        queue: InvocationQueueABC,
        images_new: ImageService,
        graph_library: ItemStorageABC["LibraryGraph"],
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
        processor: "InvocationProcessorABC",
        restoration: RestorationServices,
    ):
        self.model_manager = model_manager
        self.events = events
        self.logger = logger
        self.latents = latents
        self.images = images
        self.queue = queue
        self.images_new = images_new
        self.graph_library = graph_library
        self.graph_execution_manager = graph_execution_manager
        self.processor = processor
        self.restoration = restoration
@@ -2,11 +2,12 @@
 
 from abc import ABC
 from threading import Event, Thread
-from .graph import Graph, GraphExecutionState
-from .item_storage import ItemStorageABC
 
 from ..invocations.baseinvocation import InvocationContext
-from .invocation_services import InvocationServices
+from .graph import Graph, GraphExecutionState
+from .invocation_queue import InvocationQueueABC, InvocationQueueItem
+from .invocation_services import InvocationServices
+from .item_storage import ItemStorageABC
 
 
 class Invoker:
@@ -14,14 +15,13 @@ class Invoker:
 
     services: InvocationServices
 
-    def __init__(self,
-            services: InvocationServices
-    ):
+    def __init__(self, services: InvocationServices):
         self.services = services
         self._start()
 
-
-    def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> str|None:
+    def invoke(
+        self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
+    ) -> str | None:
         """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
 
         # Get the next invocation
@@ -33,58 +33,52 @@ class Invoker:
         self.services.graph_execution_manager.set(graph_execution_state)
 
         # Queue the invocation
-        print(f'queueing item {invocation.id}')
-        self.services.queue.put(InvocationQueueItem(
-            #session_id = session.id,
-            graph_execution_state_id = graph_execution_state.id,
-            invocation_id = invocation.id,
-            invoke_all = invoke_all
-        ))
+        self.services.queue.put(
+            InvocationQueueItem(
+                # session_id = session.id,
+                graph_execution_state_id=graph_execution_state.id,
+                invocation_id=invocation.id,
+                invoke_all=invoke_all,
+            )
+        )
 
         return invocation.id
 
-
-    def create_execution_state(self, graph: Graph|None = None) -> GraphExecutionState:
+    def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
         """Creates a new execution state for the given graph"""
-        new_state = GraphExecutionState(graph = Graph() if graph is None else graph)
+        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
         self.services.graph_execution_manager.set(new_state)
         return new_state
 
     def cancel(self, graph_execution_state_id: str) -> None:
         """Cancels the given execution state"""
         self.services.queue.cancel(graph_execution_state_id)
 
     def __start_service(self, service) -> None:
         # Call start() method on any services that have it
-        start_op = getattr(service, 'start', None)
+        start_op = getattr(service, "start", None)
         if callable(start_op):
             start_op(self)
 
-
     def __stop_service(self, service) -> None:
         # Call stop() method on any services that have it
-        stop_op = getattr(service, 'stop', None)
+        stop_op = getattr(service, "stop", None)
         if callable(stop_op):
             stop_op(self)
 
-
     def _start(self) -> None:
         """Starts the invoker. This is called automatically when the invoker is created."""
         for service in vars(self.services):
             self.__start_service(getattr(self.services, service))
 
         for service in vars(self.services):
             self.__start_service(getattr(self.services, service))
 
-
     def stop(self) -> None:
         """Stops the invoker. A new invoker will have to be created to execute further."""
         # First stop all services
         for service in vars(self.services):
             self.__stop_service(getattr(self.services, service))
 
         for service in vars(self.services):
             self.__stop_service(getattr(self.services, service))
 
         self.services.queue.put(None)
 
 
 class InvocationProcessorABC(ABC):
-    pass
+    pass
@@ -1,19 +1,21 @@
-from typing import Callable, TypeVar, Generic
+from abc import ABC, abstractmethod
+from typing import Callable, Generic, TypeVar
+
 from pydantic import BaseModel, Field
 from pydantic.generics import GenericModel
-from abc import ABC, abstractmethod
 
-T = TypeVar('T', bound=BaseModel)
+T = TypeVar("T", bound=BaseModel)
 
 
 class PaginatedResults(GenericModel, Generic[T]):
     """Paginated results"""
-    items: list[T] = Field(description = "Items")
-    page: int = Field(description = "Current Page")
-    pages: int = Field(description = "Total number of pages")
-    per_page: int = Field(description = "Number of items per page")
-    total: int = Field(description = "Total number of items in result")
 
+    #fmt: off
+    items: list[T] = Field(description="Items")
+    page: int = Field(description="Current Page")
+    pages: int = Field(description="Total number of pages")
+    per_page: int = Field(description="Number of items per page")
+    total: int = Field(description="Total number of items in result")
+    #fmt: on
 
 
 class ItemStorageABC(ABC, Generic[T]):
     _on_changed_callbacks: list[Callable[[T], None]]
@@ -24,6 +26,7 @@ class ItemStorageABC(ABC, Generic[T]):
         self._on_deleted_callbacks = list()
 
     """Base item storage class"""
+
     @abstractmethod
     def get(self, item_id: str) -> T:
         pass
@@ -37,7 +40,9 @@ class ItemStorageABC(ABC, Generic[T]):
         pass
 
     @abstractmethod
-    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
+    def search(
+        self, query: str, page: int = 0, per_page: int = 10
+    ) -> PaginatedResults[T]:
         pass
 
     def on_changed(self, on_changed: Callable[[T], None]) -> None:
@@ -51,7 +56,7 @@ class ItemStorageABC(ABC, Generic[T]):
     def _on_changed(self, item: T) -> None:
         for callback in self._on_changed_callbacks:
             callback(item)
-
+
     def _on_deleted(self, item_id: str) -> None:
         for callback in self._on_deleted_callbacks:
             callback(item_id)
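
A small usage sketch of the PaginatedResults model changed above (editorial addition, not part of the diff; Widget is a hypothetical item type):

    from pydantic import BaseModel

    class Widget(BaseModel):
        name: str

    results = PaginatedResults[Widget](
        items=[Widget(name="a"), Widget(name="b")],
        page=0, pages=1, per_page=10, total=2,
    )
    assert results.items[0].name == "a"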

invokeai/app/services/latent_storage.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict

import torch


class LatentsStorageBase(ABC):
    """Responsible for storing and retrieving latents."""

    @abstractmethod
    def get(self, name: str) -> torch.Tensor:
        pass

    @abstractmethod
    def save(self, name: str, data: torch.Tensor) -> None:
        pass

    @abstractmethod
    def delete(self, name: str) -> None:
        pass


class ForwardCacheLatentsStorage(LatentsStorageBase):
    """Caches the latest N latents in memory, writing through to and reading from underlying storage"""

    __cache: Dict[str, torch.Tensor]
    __cache_ids: Queue
    __max_cache_size: int
    __underlying_storage: LatentsStorageBase

    def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
        self.__underlying_storage = underlying_storage
        self.__cache = dict()
        self.__cache_ids = Queue()
        self.__max_cache_size = max_cache_size

    def get(self, name: str) -> torch.Tensor:
        cache_item = self.__get_cache(name)
        if cache_item is not None:
            return cache_item

        latent = self.__underlying_storage.get(name)
        self.__set_cache(name, latent)
        return latent

    def save(self, name: str, data: torch.Tensor) -> None:
        self.__underlying_storage.save(name, data)
        self.__set_cache(name, data)

    def delete(self, name: str) -> None:
        self.__underlying_storage.delete(name)
        if name in self.__cache:
            del self.__cache[name]

    def __get_cache(self, name: str) -> torch.Tensor | None:
        return None if name not in self.__cache else self.__cache[name]

    def __set_cache(self, name: str, data: torch.Tensor):
        if name not in self.__cache:
            self.__cache[name] = data
            self.__cache_ids.put(name)
            if self.__cache_ids.qsize() > self.__max_cache_size:
                self.__cache.pop(self.__cache_ids.get())


class DiskLatentsStorage(LatentsStorageBase):
    """Stores latents in a folder on disk without caching"""

    __output_folder: str

    def __init__(self, output_folder: str):
        self.__output_folder = output_folder
        Path(output_folder).mkdir(parents=True, exist_ok=True)

    def get(self, name: str) -> torch.Tensor:
        latent_path = self.get_path(name)
        return torch.load(latent_path)

    def save(self, name: str, data: torch.Tensor) -> None:
        latent_path = self.get_path(name)
        torch.save(data, latent_path)

    def delete(self, name: str) -> None:
        latent_path = self.get_path(name)
        os.remove(latent_path)

    def get_path(self, name: str) -> str:
        return os.path.join(self.__output_folder, name)
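
A minimal composition sketch (editorial addition, not part of this diff), wrapping the disk store with the write-through cache; the folder path is hypothetical:

    import torch

    storage = ForwardCacheLatentsStorage(DiskLatentsStorage("/tmp/latents"), max_cache_size=4)
    storage.save("latents-1", torch.zeros(1, 4, 64, 64))  # persisted to disk and cached
    t = storage.get("latents-1")   # served from the in-memory cache, no disk read
    storage.delete("latents-1")    # removed from disk and evicted from the cache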

invokeai/app/services/metadata.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict

from PIL import Image, PngImagePlugin
from pydantic import BaseModel

from invokeai.app.models.image import ImageType, is_image_type


class MetadataImageField(TypedDict):
    """Pydantic-less ImageField, used for metadata parsing."""

    image_type: ImageType
    image_name: str


class MetadataLatentsField(TypedDict):
    """Pydantic-less LatentsField, used for metadata parsing."""

    latents_name: str


class MetadataColorField(TypedDict):
    """Pydantic-less ColorField, used for metadata parsing"""

    r: int
    g: int
    b: int
    a: int


# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
    str,
    None
    | str
    | int
    | float
    | bool
    | MetadataImageField
    | MetadataLatentsField
    | MetadataColorField,
]


class InvokeAIMetadata(TypedDict, total=False):
    """InvokeAI-specific metadata format."""

    session_id: Optional[str]
    node: Optional[NodeMetadata]


def build_invokeai_metadata_pnginfo(
    metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
    """Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
    pnginfo = PngImagePlugin.PngInfo()

    if metadata is not None:
        pnginfo.add_text("invokeai", json.dumps(metadata))

    return pnginfo


class MetadataServiceBase(ABC):
    @abstractmethod
    def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
        """Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
        pass

    @abstractmethod
    def build_metadata(
        self, session_id: str, node: BaseModel
    ) -> InvokeAIMetadata | None:
        """Builds an InvokeAIMetadata object"""
        pass

    # @abstractmethod
    # def create_metadata(self, session_id: str, node_id: str) -> dict:
    #     """Creates metadata for a result"""
    #     pass


class PngMetadataService(MetadataServiceBase):
    """Handles loading and building metadata for images."""

    # TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
    def _load_metadata(self, image: Image.Image) -> dict | None:
        """Loads a specific info entry from a PIL Image."""

        try:
            info = image.info.get("invokeai")

            if type(info) is not str:
                return None

            loaded_metadata = json.loads(info)

            if type(loaded_metadata) is not dict:
                return None

            if len(loaded_metadata.items()) == 0:
                return None

            return loaded_metadata
        except Exception:
            # Missing or malformed metadata is treated as no metadata
            return None

    def get_metadata(self, image: Image.Image) -> dict | None:
        """Retrieves an image's metadata as a dict"""
        loaded_metadata = self._load_metadata(image)

        return loaded_metadata

    def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
        metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())

        return metadata
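
A short round-trip sketch for the service above (editorial addition, not part of this diff; `some_node` stands for any pydantic invocation model, `image` for a PIL image, and the path is illustrative):

    from PIL import Image

    svc = PngMetadataService()
    meta = svc.build_metadata(session_id="session-1", node=some_node)
    image.save("/tmp/out.png", pnginfo=build_invokeai_metadata_pnginfo(meta))  # embeds the "invokeai" tEXt entry
    loaded = svc.get_metadata(Image.open("/tmp/out.png"))  # the same metadata, parsed back out of the PNG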

invokeai/app/services/model_manager_initializer.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import os
import sys
import torch
import types
from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path

import invokeai.version
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals


# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
    if not config.conf:
        config_file = os.path.join(Globals.root, "configs", "models.yaml")
        if not os.path.exists(config_file):
            report_model_error(
                config, FileNotFoundError(f"The file {config_file} could not be found."), logger
            )

    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
    logger.info(f'InvokeAI runtime directory is "{Globals.root}"')

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers  # type: ignore

    transformers.logging.set_verbosity_error()
    import diffusers

    diffusers.logging.set_verbosity_error()

    # normalize the config directory relative to root
    if not os.path.isabs(config.conf):
        config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))

    if config.embeddings:
        if not os.path.isabs(config.embedding_path):
            embedding_path = os.path.normpath(
                os.path.join(Globals.root, config.embedding_path)
            )
        else:
            embedding_path = config.embedding_path
    else:
        embedding_path = None

    # migrate legacy models
    ModelManager.migrate_models()

    # creating the model manager
    try:
        device = torch.device(choose_torch_device())
        precision = (
            "float16"
            if config.precision == "float16"
            else "float32"
            if config.precision == "float32"
            else choose_precision(device)
        )

        model_manager = ModelManager(
            OmegaConf.load(config.conf),
            precision=precision,
            device_type=device,
            max_loaded_models=config.max_loaded_models,
            # guard against Path(None) when embeddings are disabled
            embedding_path=Path(embedding_path) if embedding_path else None,
            logger=logger,
        )
    except (FileNotFoundError, TypeError, AssertionError) as e:
        report_model_error(config, e, logger)
    except (IOError, KeyError) as e:
        logger.error(f"{e}. Aborting.")
        sys.exit(-1)

    # try to autoconvert new models
    # autoimport new .ckpt files
    if path := config.autoconvert:
        model_manager.autoconvert_weights(
            conf_path=config.conf,
            weights_directory=path,
        )
    logger.info("Model manager initialized")
    return model_manager


def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
    logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
    logger.error(
        "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
    )
    yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
    if yes_to_all:
        logger.warning(
            "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
        )
    else:
        response = input(
            "Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
        )
        if response.startswith(("n", "N")):
            return

    logger.info("invokeai-configure is launching....\n")

    # Match arguments that were set on the CLI
    # only the arguments accepted by the configuration script are parsed
    root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
    config = ["--config", opt.conf] if opt.conf is not None else []
    sys.argv = ["invokeai-configure"]
    sys.argv.extend(root_dir)
    sys.argv.extend(config)  # config is a list of CLI arguments, not a dict
    if yes_to_all is not None:
        for arg in yes_to_all.split():
            sys.argv.append(arg)

    from invokeai.frontend.install import invokeai_configure

    invokeai_configure()
    # TODO: Figure out how to restart
    # print('** InvokeAI will now restart')
    # sys.argv = previous_args
    # main()  # would rather do a os.exec(), but doesn't exist?
    # sys.exit(0)

invokeai/app/services/models/image_record.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import datetime
import sqlite3
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp


class ImageRecord(BaseModel):
    """Deserialized image record."""

    image_name: str = Field(description="The name of the image.")
    image_type: ImageType = Field(description="The type of the image.")
    image_category: ImageCategory = Field(description="The category of the image.")
    created_at: Union[datetime.datetime, str] = Field(
        description="The created timestamp of the image."
    )
    session_id: Optional[str] = Field(default=None, description="The session ID.")
    node_id: Optional[str] = Field(default=None, description="The node ID.")
    metadata: Optional[ImageMetadata] = Field(
        default=None, description="The image's metadata."
    )


class ImageDTO(ImageRecord):
    """Deserialized image record with URLs."""

    image_url: str = Field(description="The URL of the image.")
    thumbnail_url: str = Field(description="The thumbnail URL of the image.")


def image_record_to_dto(
    image_record: ImageRecord, image_url: str, thumbnail_url: str
) -> ImageDTO:
    """Converts an image record to an image DTO."""
    return ImageDTO(
        image_name=image_record.image_name,
        image_type=image_record.image_type,
        image_category=image_record.image_category,
        created_at=image_record.created_at,
        session_id=image_record.session_id,
        node_id=image_record.node_id,
        metadata=image_record.metadata,
        image_url=image_url,
        thumbnail_url=thumbnail_url,
    )


def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
    """Deserializes an image record."""

    image_dict = dict(image_row)

    image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))

    raw_metadata = image_dict.get("metadata", "{}")

    metadata = ImageMetadata.parse_raw(raw_metadata)

    return ImageRecord(
        image_name=image_dict.get("id", "unknown"),
        session_id=image_dict.get("session_id", None),
        node_id=image_dict.get("node_id", None),
        metadata=metadata,
        image_type=image_type,
        image_category=ImageCategory(
            image_dict.get("image_category", ImageCategory.IMAGE.value)
        ),
        created_at=image_dict.get("created_at", get_iso_timestamp()),
    )
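
A brief read-path sketch (editorial addition, not part of this diff; the db path and the table/column names are illustrative, following the record fields above):

    import sqlite3

    conn = sqlite3.connect("invokeai.db")
    conn.row_factory = sqlite3.Row  # deserialize_image_record needs dict-like rows
    row = conn.execute("SELECT * FROM images WHERE id = ?", ("some-image-name",)).fetchone()
    record = deserialize_image_record(row)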

invokeai/app/services/processor.py (new file, 174 lines)
@@ -0,0 +1,174 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
from typing import Any, TypeGuard

from invokeai.app.invocations.image import ImageOutput
from invokeai.app.models.image import ImageType
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException

import invokeai.backend.util.logging as logger


class DefaultInvocationProcessor(InvocationProcessorABC):
    __invoker_thread: Thread
    __stop_event: Event
    __invoker: Invoker
    __threadLimit: BoundedSemaphore

    def start(self, invoker) -> None:
        # if we do want multithreading at some point, we could make this configurable
        self.__threadLimit = BoundedSemaphore(1)
        self.__invoker = invoker
        self.__stop_event = Event()
        self.__invoker_thread = Thread(
            name="invoker_processor",
            target=self.__process,
            kwargs=dict(stop_event=self.__stop_event),
        )
        self.__invoker_thread.daemon = (
            True  # TODO: make async and do not use threads
        )
        self.__invoker_thread.start()

    def stop(self, *args, **kwargs) -> None:
        self.__stop_event.set()

    def __process(self, stop_event: Event):
        try:
            self.__threadLimit.acquire()
            while not stop_event.is_set():
                try:
                    queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                except Exception as e:
                    logger.debug("Exception while getting from queue: %s" % e)
                    queue_item = None  # avoid an unbound local; fall through to the sleep below

                if not queue_item:  # Probably stopping
                    # do not hammer the queue
                    time.sleep(0.5)
                    continue

                graph_execution_state = (
                    self.__invoker.services.graph_execution_manager.get(
                        queue_item.graph_execution_state_id
                    )
                )
                invocation = graph_execution_state.execution_graph.get_node(
                    queue_item.invocation_id
                )

                # get the source node id to provide to clients (the prepared node id is not as useful)
                source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]

                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
                    node=invocation.dict(),
                    source_node_id=source_node_id,
                )

                # Invoke
                try:
                    outputs = invocation.invoke(
                        InvocationContext(
                            services=self.__invoker.services,
                            graph_execution_state_id=graph_execution_state.id,
                        )
                    )

                    # Check queue to see if this is canceled, and skip if so
                    if self.__invoker.services.queue.is_canceled(
                        graph_execution_state.id
                    ):
                        continue

                    # Save outputs and history
                    graph_execution_state.complete(invocation.id, outputs)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    def is_image_output(obj: Any) -> TypeGuard[ImageOutput]:
                        return obj.__class__ == ImageOutput

                    outputs_dict = outputs.dict()

                    if is_image_output(outputs):
                        image_url = self.__invoker.services.images_new.get_url(
                            ImageType.RESULT, outputs.image.image_name
                        )
                        thumbnail_url = self.__invoker.services.images_new.get_url(
                            ImageType.RESULT, outputs.image.image_name, True
                        )
                    else:
                        image_url = None
                        thumbnail_url = None

                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        result=outputs_dict,
                        image_url=image_url,
                        thumbnail_url=thumbnail_url,
                    )

                except KeyboardInterrupt:
                    pass

                except CanceledException:
                    pass

                except Exception as e:
                    error = traceback.format_exc()

                    # Save error
                    graph_execution_state.set_node_error(invocation.id, error)

                    # Save the state changes
                    self.__invoker.services.graph_execution_manager.set(
                        graph_execution_state
                    )

                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
                        node=invocation.dict(),
                        source_node_id=source_node_id,
                        error=error,
                    )

                    pass

                # Check queue to see if this is canceled, and skip if so
                if self.__invoker.services.queue.is_canceled(
                    graph_execution_state.id
                ):
                    continue

                # Queue any further commands if invoking all
                is_complete = graph_execution_state.is_complete()
                if queue_item.invoke_all and not is_complete:
                    try:
                        self.__invoker.invoke(graph_execution_state, invoke_all=True)
                    except Exception as e:
                        logger.error("Error while invoking: %s" % e)
                        self.__invoker.services.events.emit_invocation_error(
                            graph_execution_state_id=graph_execution_state.id,
                            node=invocation.dict(),
                            source_node_id=source_node_id,
                            error=traceback.format_exc(),
                        )
                elif is_complete:
                    self.__invoker.services.events.emit_graph_execution_complete(
                        graph_execution_state.id
                    )

        except KeyboardInterrupt:
            pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
        finally:
            self.__threadLimit.release()

invokeai/app/services/proposeddesign.py (new file, 657 lines)
@@ -0,0 +1,657 @@
from abc import ABC, abstractmethod
from enum import Enum
import enum
import sqlite3
import threading
from typing import Optional, Type, TypeVar, Union
from PIL.Image import Image as PILImage
from pydantic import BaseModel, Field
from torch import Tensor

from invokeai.app.services.item_storage import PaginatedResults


"""
Substantial proposed changes to the management of images and tensors.

tl;dr:
With the upcoming move to latents-only nodes, we need to handle metadata differently. After struggling with this unsuccessfully - trying to smoosh it in to the existing setup - I believe we need to expand the scope of the refactor to include the management of images and latents - and make `latents` a special case of `tensor`.

full story:
The consensus for tensor-only nodes' metadata was to traverse the execution graph and grab the core parameters to write to the image. This was straightforward, and I've written functions to find the nearest t2l/l2l, noise, and compel nodes and build the metadata from those.

But struggling to integrate this and the associated edge cases brought up a number of issues deeper in the system (some of which I had previously implemented). The ImageStorageService is doing way too much, and we need to be able to retrieve the session given an image/latents id, which is not currently feasible due to SQLite's JSON parsing performance.

I made a new ResultsService and `results` table in the db to facilitate this. This first attempt failed because it doesn't handle uploads and leaves the codebase messy.

So I've spent the day trying to figure out how to handle this in a sane way, and I think I've got something decent. I've described some changes to service bases and the database below.

The gist of it is to store the core parameters for an image in its metadata when the image is saved, but never to read from it. Instead, the same metadata is stored in the database, which will be set up for efficient access. So when a page of images is requested, the metadata comes from the db instead of a filesystem operation.

The URL generation responsibilities have been split off the image storage service into a URL service. New database services/tables for images and tensors are added. These services will provide paginated images/tensors for the API to serve. This also paves the way for handling tensors as first-class outputs.
"""


# TODO: Make a new model for this
class ResourceOrigin(str, Enum):
    """The origin of a resource (eg image or tensor)."""

    RESULTS = "results"
    UPLOADS = "uploads"
    INTERMEDIATES = "intermediates"


class ImageKind(str, Enum):
    """The kind of an image."""

    IMAGE = "image"
    CONTROL_IMAGE = "control_image"


class TensorKind(str, Enum):
    """The kind of a tensor."""

    IMAGE_TENSOR = "tensor"
    CONDITIONING = "conditioning"


"""
Core Generation Metadata Pydantic Model

I've already implemented the code to traverse a session to build this object.
"""


class CoreGenerationMetadata(BaseModel):
    """Core generation metadata for an image/tensor generated in InvokeAI.

    Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.

    Full metadata may be accessed by querying for the session in the `graph_executions` table.
    """

    positive_conditioning: Optional[str] = Field(
        description="The positive conditioning."
    )
    negative_conditioning: Optional[str] = Field(
        description="The negative conditioning."
    )
    width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
    height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
    seed: Optional[int] = Field(description="The seed used for noise generation.")
    cfg_scale: Optional[float] = Field(
        description="The classifier-free guidance scale."
    )
    steps: Optional[int] = Field(description="The number of steps used for inference.")
    scheduler: Optional[str] = Field(description="The scheduler used for inference.")
    model: Optional[str] = Field(description="The model used for inference.")
    strength: Optional[float] = Field(
        description="The strength used for image-to-image/tensor-to-tensor."
    )
    image: Optional[str] = Field(description="The ID of the initial image.")
    tensor: Optional[str] = Field(description="The ID of the initial tensor.")
    # Pending model refactor:
    # vae: Optional[str] = Field(description="The VAE used for decoding.")
    # unet: Optional[str] = Field(description="The UNet used for inference.")
    # clip: Optional[str] = Field(description="The CLIP Encoder used for conditioning.")

"""
|
||||
Minimal Uploads Metadata Model
|
||||
"""
|
||||
|
||||
|
||||
class UploadsMetadata(BaseModel):
|
||||
"""Limited metadata for an uploaded image/tensor."""
|
||||
|
||||
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
|
||||
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
|
||||
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
|
||||
# from another SD application or InvokeAI, so we need to make it very flexible. I think it's
|
||||
# best to just store it as a string and let the frontend parse it.
|
||||
# If the upload is a tensor type, this will be omitted.
|
||||
extra: Optional[str] = Field(
|
||||
description="Extra metadata, extracted from the PNG tEXt chunk."
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Slimmed-down Image Storage Service Base
|
||||
- No longer lists images or generates URLs - only stores and retrieves images.
|
||||
- OSS implementation for disk storage
|
||||
"""
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImage,
|
||||
image_kind: ImageKind,
|
||||
origin: ResourceOrigin,
|
||||
context_id: str,
|
||||
node_id: str,
|
||||
metadata: CoreGenerationMetadata,
|
||||
) -> str:
|
||||
"""Saves an image and its thumbnail, returning its unique identifier."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str, thumbnail: bool = False) -> Union[PILImage, None]:
|
||||
"""Retrieves an image as a PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
|
||||
class TensorStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving tensors."""
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
tensor: Tensor,
|
||||
tensor_kind: TensorKind,
|
||||
origin: ResourceOrigin,
|
||||
context_id: str,
|
||||
node_id: str,
|
||||
metadata: CoreGenerationMetadata,
|
||||
) -> str:
|
||||
"""Saves a tensor, returning its unique identifier."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str, thumbnail: bool = False) -> Union[Tensor, None]:
|
||||
"""Retrieves a tensor as a torch Tensor."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes a tensor."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
New Url Service Base
|
||||
- Abstracts the logic for generating URLs out of the storage service
|
||||
- OSS implementation for locally-hosted URLs
|
||||
- Also provides a method to get the internal path to a resource (for OSS, the FS path)
|
||||
"""
|
||||
|
||||
|
||||
class ResourceLocationServiceBase(ABC):
|
||||
"""Responsible for locating resources (eg images or tensors)."""
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, id: str) -> str:
|
||||
"""Gets the URL for a resource."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, id: str) -> str:
|
||||
"""Gets the path for a resource."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
New Images Database Service Base
|
||||
|
||||
This is a new service that will be responsible for the new `images` table(s):
|
||||
- Storing images in the table
|
||||
- Retrieving individual images and pages of images
|
||||
- Deleting individual images
|
||||
|
||||
Operations will typically use joins with the various `images` tables.
|
||||
"""
|
||||
|
||||
|
||||
class ImagesDbServiceBase(ABC):
|
||||
"""Responsible for interfacing with `images` table."""
|
||||
|
||||
class GeneratedImageEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the image.")
|
||||
session_id: str = Field(description="The session ID.")
|
||||
node_id: str = Field(description="The node ID.")
|
||||
metadata: CoreGenerationMetadata = Field(
|
||||
description="The metadata for the image."
|
||||
)
|
||||
|
||||
class UploadedImageEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the image.")
|
||||
metadata: UploadsMetadata = Field(description="The metadata for the image.")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Union[GeneratedImageEntity, UploadedImageEntity, None]:
|
||||
"""Gets an image from the `images` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self, image_kind: ImageKind, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[Union[GeneratedImageEntity, UploadedImageEntity]]:
|
||||
"""Gets a page of images from the `images` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes an image from the `images` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(
|
||||
self,
|
||||
id: str,
|
||||
image_kind: ImageKind,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: CoreGenerationMetadata | UploadsMetadata,
|
||||
) -> None:
|
||||
"""Sets an image in the `images` table."""
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
New Tensor Database Service Base
|
||||
|
||||
This is a new service that will be responsible for the new `tensor` table:
|
||||
- Storing tensor in the table
|
||||
- Retrieving individual tensor and pages of tensor
|
||||
- Deleting individual tensor
|
||||
|
||||
Operations will always use joins with the `tensor_metadata` table.
|
||||
"""
|
||||
|
||||
|
||||
class TensorDbServiceBase(ABC):
|
||||
"""Responsible for interfacing with `tensor` table."""
|
||||
|
||||
class GeneratedTensorEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the tensor.")
|
||||
session_id: str = Field(description="The session ID.")
|
||||
node_id: str = Field(description="The node ID.")
|
||||
metadata: CoreGenerationMetadata = Field(
|
||||
description="The metadata for the tensor."
|
||||
)
|
||||
|
||||
class UploadedTensorEntity(BaseModel):
|
||||
id: str = Field(description="The unique identifier for the tensor.")
|
||||
metadata: UploadsMetadata = Field(description="The metadata for the tensor.")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: str) -> Union[GeneratedTensorEntity, UploadedTensorEntity, None]:
|
||||
"""Gets a tensor from the `tensor` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self, tensor_kind: TensorKind, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[Union[GeneratedTensorEntity, UploadedTensorEntity]]:
|
||||
"""Gets a page of tensor from the `tensor` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str) -> None:
|
||||
"""Deletes a tensor from the `tensor` table."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(
|
||||
self,
|
||||
id: str,
|
||||
tensor_kind: TensorKind,
|
||||
session_id: Optional[str],
|
||||
node_id: Optional[str],
|
||||
metadata: CoreGenerationMetadata | UploadsMetadata,
|
||||
) -> None:
|
||||
"""Sets a tensor in the `tensor` table."""
|
||||
pass
|
||||
|
||||
|
||||
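
# Illustrative sketch (editorial addition, not part of the proposal): how an OSS
# implementation of ImagesDbServiceBase.get might join the tables described below.
# Table and column names follow the schemas later in this file.
#
#   row = cursor.execute(
#       """--sql
#       SELECT images.id, images.image_kind, image_results.session_id,
#              image_results.node_id, images_metadata.metadata
#       FROM images
#       LEFT JOIN image_results ON image_results.images_id = images.id
#       LEFT JOIN images_metadata ON images_metadata.images_id = images.id
#       WHERE images.id = ?;
#       """,
#       (id,),
#   ).fetchone()

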
"""
|
||||
Database Changes
|
||||
|
||||
The existing tables will remain as-is, new tables will be added.
|
||||
|
||||
Tensor now also have the same types as images - `results`, `intermediates`, `uploads`. Storage, retrieval, and operations may diverge from images in the future, so they are managed separately.
|
||||
|
||||
A few `images` tables are created to store all images:
|
||||
- `results` and `intermediates` images have additional data: `session_id` and `node_id`, and may be further differentiated in the future. For this reason, they each get their own table.
|
||||
- `uploads` do not get their own table, as they are never going to have more than an `id`, `image_kind` and `timestamp`.
|
||||
- `images_metadata` holds the same image metadata that is written to the image. This table, along with the URL service, allow us to more efficiently serve images without having to read the image from storage.
|
||||
|
||||
The same tables are made for `tensor` and for the moment, implementation is expected to be identical.
|
||||
|
||||
Schemas for each table below.
|
||||
|
||||
Insertions and updates of ancillary tables (e.g. `results_images`, `images_metadata`, etc) will need to be done manually in the services, but should be straightforward. Deletion via cascading will be handled by the database.
|
||||
"""


def create_sql_values_string_from_string_enum(enum: Type[Enum]):
    """
    Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
    """

    delimiter = ", "
    values = [f"('{e.value}')" for e in enum]
    return delimiter.join(values)


def create_sql_table_from_enum(
    enum: Type[Enum],
    table_name: str,
    primary_key_name: str,
    cursor: sqlite3.Cursor,
    lock: threading.Lock,
):
    """
    Creates and populates a table to be used as a functional enum.
    """

    try:
        lock.acquire()

        values_string = create_sql_values_string_from_string_enum(enum)

        cursor.execute(
            f"""--sql
            CREATE TABLE IF NOT EXISTS {table_name} (
                {primary_key_name} TEXT PRIMARY KEY
            );
            """
        )
        cursor.execute(
            f"""--sql
            INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
            """
        )
    finally:
        lock.release()
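
# Illustrative sketch (editorial addition, not part of the proposal): creating and
# populating one of the enum tables above against an in-memory database.
#
#   import sqlite3, threading
#   conn = sqlite3.connect(":memory:")
#   lock = threading.Lock()
#   create_sql_table_from_enum(ResourceOrigin, "resource_origins", "origin_name", conn.cursor(), lock)
#   conn.execute("SELECT origin_name FROM resource_origins").fetchall()
#   # -> [('results',), ('uploads',), ('intermediates',)]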


"""
`resource_origins` functions as an enum for the ResourceOrigin model.
"""


def create_resource_origins_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    create_sql_table_from_enum(
        enum=ResourceOrigin,
        table_name="resource_origins",
        primary_key_name="origin_name",
        cursor=cursor,
        lock=lock,
    )


"""
`image_kinds` functions as an enum for the ImageKind model.
"""


def create_image_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    create_sql_table_from_enum(
        enum=ImageKind,
        table_name="image_kinds",
        primary_key_name="kind_name",
        cursor=cursor,
        lock=lock,
    )


"""
`tensor_kinds` functions as an enum for the TensorKind model.
"""


def create_tensor_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    create_sql_table_from_enum(
        enum=TensorKind,
        table_name="tensor_kinds",
        primary_key_name="kind_name",
        cursor=cursor,
        lock=lock,
    )


"""
`images` stores all images, regardless of type
"""


def create_images_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    try:
        lock.acquire()

        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS images (
                id TEXT PRIMARY KEY,
                origin TEXT,
                image_kind TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
                FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)
            );
            """
        )
        cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
            """
        )
        cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);
            """
        )
        cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);
            """
        )
    finally:
        lock.release()


"""
`image_results` stores additional data specific to `results` images.
"""


def create_image_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    try:
        lock.acquire()

        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS image_results (
                images_id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                node_id TEXT NOT NULL,
                FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
            );
            """
        )
        cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_image_results_images_id ON image_results(images_id);
            """
        )
    finally:
        lock.release()


"""
`image_intermediates` stores additional data specific to `intermediates` images
"""


def create_image_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    try:
        lock.acquire()

        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS image_intermediates (
                images_id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                node_id TEXT NOT NULL,
                FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
            );
            """
        )
        cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_image_intermediates_images_id ON image_intermediates(images_id);
            """
        )
    finally:
        lock.release()


"""
`images_metadata` stores basic metadata for any image type
"""


def create_images_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    try:
        lock.acquire()

        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS images_metadata (
                images_id TEXT PRIMARY KEY,
                metadata TEXT,
                FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
            );
            """
        )
        cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);
            """
        )
    finally:
        lock.release()


# `tensors` table: stores references to tensors


def create_tensors_table(cursor: sqlite3.Cursor, lock: threading.Lock):
    try:
        lock.acquire()

        cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS tensors (
                id TEXT PRIMARY KEY,
                origin TEXT,
                tensor_kind TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
                FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)
            );
            """
        )
        cursor.execute(
            """--sql
            CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);
            """
        )
        cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);
            """
        )
        cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);
            """
        )
    finally:
        lock.release()
|
||||
|
||||
|
||||
# `results_tensor` stores additional data specific to `result` tensor
|
||||
|
||||
|
||||
def create_tensor_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensor_results (
|
||||
tensor_id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_results_tensor_id ON tensor_results(tensor_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
# `tensor_intermediates` stores additional data specific to `intermediate` tensor
|
||||
|
||||
|
||||
def create_tensor_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensor_intermediates (
|
||||
tensor_id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
node_id TEXT NOT NULL,
|
||||
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_intermediates_tensor_id ON tensor_intermediates(tensor_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
|
||||
# `tensors_metadata` table: stores generated/transformed metadata for tensor
|
||||
|
||||
|
||||
def create_tensors_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
|
||||
try:
|
||||
lock.acquire()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tensors_metadata (
|
||||
tensor_id TEXT PRIMARY KEY,
|
||||
metadata TEXT,
|
||||
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensor_id ON tensors_metadata(tensor_id);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
lock.release()
|
||||
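Taken together, the creator functions above are plain procedures over a shared cursor and lock. A minimal wiring sketch (an editorial illustration, not part of the diff; it assumes the `resource_origins` and `tensor_kinds` lookup tables referenced by the foreign keys are created first, e.g. with the `create_enum_table` helper shown later in this diff):

    import sqlite3
    import threading

    conn = sqlite3.connect(":memory:", check_same_thread=False)
    cursor = conn.cursor()
    lock = threading.Lock()

    # `tensors` declares foreign keys into resource_origins/tensor_kinds,
    # so those lookup tables must exist before rows are inserted.
    create_images_metadata_table(cursor, lock)
    create_tensors_table(cursor, lock)
    create_tensor_results_table(cursor, lock)
    create_tensor_intermediates_table(cursor, lock)
    create_tensors_metadata_table(cursor, lock)
    conn.commit()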
112  invokeai/app/services/restoration_services.py  Normal file
@@ -0,0 +1,112 @@
import sys
import traceback
import types

import torch

from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE


# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
    """Face restoration and upscaling"""

    def __init__(self, args, logger: types.ModuleType):
        try:
            gfpgan, codeformer, esrgan = None, None, None
            if args.restore or args.esrgan:
                restoration = Restoration()
                if args.restore:
                    gfpgan, codeformer = restoration.load_face_restore_models(
                        args.gfpgan_model_path
                    )
                else:
                    logger.info("Face restoration disabled")
                if args.esrgan:
                    esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                else:
                    logger.info("Upscaling disabled")
            else:
                logger.info("Face restoration and upscaling disabled")
        except (ModuleNotFoundError, ImportError):
            print(traceback.format_exc(), file=sys.stderr)
            logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
        self.device = torch.device(choose_torch_device())
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan
        self.logger = logger
        self.logger.info("Face restoration initialized")

    # Note that this one method does both gfpgan/codeformer face reconstruction
    # and esrgan upscaling.
    # TODO: refactor into separate methods
    def upscale_and_reconstruct(
        self,
        image_list,
        facetool="gfpgan",
        upscale=None,
        upscale_denoise_str=0.75,
        strength=0.0,
        codeformer_fidelity=0.75,
        save_original=False,
        image_callback=None,
        prefix=None,
    ):
        results = []
        for r in image_list:
            image, seed = r
            try:
                if strength > 0:
                    if self.gfpgan is not None or self.codeformer is not None:
                        if facetool == "gfpgan":
                            if self.gfpgan is None:
                                self.logger.info(
                                    "GFPGAN not found. Face restoration is disabled."
                                )
                            else:
                                image = self.gfpgan.process(image, strength, seed)
                        if facetool == "codeformer":
                            if self.codeformer is None:
                                self.logger.info(
                                    "CodeFormer not found. Face restoration is disabled."
                                )
                            else:
                                cf_device = (
                                    CPU_DEVICE if self.device == MPS_DEVICE else self.device
                                )
                                image = self.codeformer.process(
                                    image=image,
                                    strength=strength,
                                    device=cf_device,
                                    seed=seed,
                                    fidelity=codeformer_fidelity,
                                )
                    else:
                        self.logger.info("Face restoration is disabled.")
                if upscale is not None:
                    if self.esrgan is not None:
                        if len(upscale) < 2:
                            upscale.append(0.75)
                        image = self.esrgan.process(
                            image,
                            upscale[1],
                            seed,
                            int(upscale[0]),
                            denoise_str=upscale_denoise_str,
                        )
                    else:
                        self.logger.info("ESRGAN is disabled. Image not upscaled.")
            except Exception as e:
                self.logger.info(
                    f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
                )

            if image_callback is not None:
                image_callback(image, seed, upscaled=True, use_prefix=prefix)
            else:
                r[0] = image

            results.append([image, seed])

        return results
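A usage sketch for the service above (editorial illustration; the namespace values and `pil_image` are hypothetical stand-ins for InvokeAI's real CLI args and a PIL image, though the attribute names match those read in `__init__`):

    import logging
    from argparse import Namespace

    args = Namespace(
        restore=True,            # load the GFPGAN/CodeFormer face-restoration models
        esrgan=True,             # load the ESRGAN upscaler
        gfpgan_model_path=None,  # hypothetical: fall back to the default model path
        esrgan_bg_tile=400,      # tile size used during background upscaling
    )
    # `logger` is annotated as a module, so the stdlib logging module fits here.
    restoration = RestorationServices(args, logger=logging)

    # `image_list` holds mutable [image, seed] pairs; `upscale` is [scale, strength].
    results = restoration.upscale_and_reconstruct(
        image_list=[[pil_image, 42]],  # pil_image: a PIL.Image produced elsewhere
        facetool="gfpgan",
        strength=0.5,       # > 0 enables face restoration
        upscale=[2, 0.75],  # 2x upscale at 0.75 strength
    )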
466  invokeai/app/services/results.py  Normal file
@@ -0,0 +1,466 @@
from abc import ABC, abstractmethod
from enum import Enum
import json
import sqlite3
from threading import Lock
from typing import Any, TypedDict, Union

import networkx as nx
from networkx import DiGraph

from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.services.graph import Edge, GraphExecutionState
from invokeai.app.invocations.latent import LatentsOutput
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp


class ResultType(str, Enum):
    image_output = "image_output"
    latents_output = "latents_output"


class Result(BaseModel):
    """A session result"""

    id: str = Field(description="Result ID")
    session_id: str = Field(description="Session ID")
    node_id: str = Field(description="Node ID")
    data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")


class ResultWithSession(BaseModel):
    """A result with its session"""

    result: Result = Field(description="The result")
    session: GraphExecutionState = Field(description="The session")


# We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
# model used by the system, because we may be parsing a graph in an old format. We can, however, use
# the Edge model, because the edge format does not change.
class LooseGraph(BaseModel):
    id: str
    nodes: dict[str, dict[str, Any]]
    edges: list[Edge]


# An intermediate type used during parsing
class NearestAncestor(TypedDict):
    node_id: str
    metadata: dict[str, Any]


# The ancestor node types that contain the core metadata
ANCESTOR_TYPES = ['t2l', 'l2l']

# The core metadata parameters in the ancestor types
ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']

# The core metadata parameters in the noise node
NOISE_FIELDS = ['seed', 'width', 'height']


# Find the nearest t2l or l2l ancestor from a given l2i node
def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
    """Returns metadata for the nearest ancestor of a given node.

    Parameters:
    G (DiGraph): A directed graph.
    node_id (str): The ID of the starting node.

    Returns:
    NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
    """

    # Retrieve the node from the graph
    node = G.nodes[node_id]

    # If the node type is one of the core metadata node types, gather the necessary metadata and return
    if node.get('type') in ANCESTOR_TYPES:
        parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
        return NearestAncestor(node_id=node_id, metadata=parsed_metadata)

    # Else, look for the ancestor in the predecessor nodes
    for predecessor in G.predecessors(node_id):
        result = find_nearest_ancestor(G, predecessor)
        if result:
            return result

    # If there are no valid ancestors, return None
    return None
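A toy illustration of the traversal (editorial; the node IDs are hypothetical). An `l2i` node whose predecessor is a `t2l` node resolves to that `t2l` node's core parameters:

    import networkx as nx

    G = nx.DiGraph()
    G.add_node("noise_1", type="noise", seed=123)
    G.add_node("t2l_1", type="t2l", steps=30, cfg_scale=7.5)
    G.add_node("l2i_1", type="l2i")
    G.add_edge("noise_1", "t2l_1")
    G.add_edge("t2l_1", "l2i_1")

    find_nearest_ancestor(G, "l2i_1")
    # -> {'node_id': 't2l_1', 'metadata': {'steps': 30, 'cfg_scale': 7.5}}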
def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
    """Collects additional metadata from nodes connected to a given node.

    Parameters:
    graph (LooseGraph): The graph.
    node_id (str): The ID of the node.

    Returns:
    dict | None: A dictionary containing additional metadata.
    """

    metadata = {}

    # Iterate over all edges in the graph
    for edge in graph.edges:
        dest_node_id = edge.destination.node_id
        dest_field = edge.destination.field
        source_node = graph.nodes[edge.source.node_id]

        # If the destination node ID matches the given node ID, gather the necessary metadata
        if dest_node_id == node_id:
            # If the destination field is 'positive_conditioning', add the 'prompt' from the source node
            if dest_field == 'positive_conditioning':
                metadata['positive_conditioning'] = source_node.get('prompt')
            # If the destination field is 'negative_conditioning', add the 'prompt' from the source node
            if dest_field == 'negative_conditioning':
                metadata['negative_conditioning'] = source_node.get('prompt')
            # If the destination field is 'noise', add the core noise fields from the source node
            if dest_field == 'noise':
                for field in NOISE_FIELDS:
                    metadata[field] = source_node.get(field)

    return metadata

def build_core_metadata(graph_raw: dict, node_id: str) -> Union[dict, None]:
    """Builds the core metadata for a given node.

    Parameters:
    graph_raw (dict): The graph structure as a parsed JSON object.
    node_id (str): The ID of the node.

    Returns:
    dict | None: A dictionary containing core metadata.
    """

    # Create a directed graph to facilitate traversal
    G = nx.DiGraph()

    # Parse the raw graph into a LooseGraph
    graph = parse_obj_as(LooseGraph, graph_raw)

    # Add nodes and edges to the graph (a separate loop variable is used so the
    # `node_id` parameter is not shadowed)
    for graph_node_id, node_data in graph.nodes.items():
        G.add_node(graph_node_id, **node_data)
    for edge in graph.edges:
        G.add_edge(edge.source.node_id, edge.destination.node_id)

    # Find the nearest ancestor of the given node
    ancestor = find_nearest_ancestor(G, node_id)

    # If no ancestor was found, return None
    if ancestor is None:
        return None

    metadata = ancestor['metadata']
    ancestor_id = ancestor['node_id']

    # Get additional metadata related to the ancestor
    addl_metadata = get_additional_metadata(graph, ancestor_id)

    # If additional metadata was found, add it to the main metadata
    if addl_metadata is not None:
        metadata.update(addl_metadata)

    return metadata

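For concreteness, here is the shape of input the function expects (an editorial sketch; node IDs and values are hypothetical, and the edges assume the Edge model parses source/destination connections carrying `node_id` and `field`):

    graph_raw = {
        "id": "exec-1",
        "nodes": {
            "prompt_1": {"id": "prompt_1", "type": "compel", "prompt": "a red barn"},
            "noise_1": {"id": "noise_1", "type": "noise", "seed": 42, "width": 512, "height": 512},
            "t2l_1": {"id": "t2l_1", "type": "t2l", "steps": 30, "cfg_scale": 7.5},
            "l2i_1": {"id": "l2i_1", "type": "l2i"},
        },
        "edges": [
            {"source": {"node_id": "prompt_1", "field": "conditioning"},
             "destination": {"node_id": "t2l_1", "field": "positive_conditioning"}},
            {"source": {"node_id": "noise_1", "field": "noise"},
             "destination": {"node_id": "t2l_1", "field": "noise"}},
            {"source": {"node_id": "t2l_1", "field": "latents"},
             "destination": {"node_id": "l2i_1", "field": "latents"}},
        ],
    }

    build_core_metadata(graph_raw, "l2i_1")
    # -> {'steps': 30, 'cfg_scale': 7.5, 'positive_conditioning': 'a red barn',
    #     'seed': 42, 'width': 512, 'height': 512}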
class ResultsServiceABC(ABC):
    """The Results service is responsible for retrieving results."""

    @abstractmethod
    def get(
        self, result_id: str, result_type: ResultType
    ) -> Union[ResultWithSession, None]:
        pass

    @abstractmethod
    def get_many(
        self, result_type: ResultType, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[ResultWithSession]:
        pass

    @abstractmethod
    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[ResultWithSession]:
        pass

    @abstractmethod
    def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
        pass

class SqliteResultsService(ResultsServiceABC):
    """SQLite implementation of the Results service."""

    _filename: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: Lock

    def __init__(self, filename: str):
        super().__init__()

        self._filename = filename
        self._lock = Lock()

        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()

        self._create_table()

    def _create_table(self):
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                CREATE TABLE IF NOT EXISTS results (
                    id TEXT PRIMARY KEY, -- the result's name
                    result_type TEXT, -- `image_output` | `latents_output`
                    node_id TEXT, -- the node that produced this result
                    session_id TEXT, -- the session that produced this result
                    created_at INTEGER, -- the time at which this result was created
                    data TEXT -- the result itself
                );
                """
            )
            self._cursor.execute(
                """--sql
                CREATE UNIQUE INDEX IF NOT EXISTS idx_result_id ON results(id);
                """
            )
        finally:
            self._lock.release()

    def _parse_joined_result(self, result_row: Any, column_names: list[str]):
        result_raw = {}
        session_raw = {}

        for idx, name in enumerate(column_names):
            if name == "session":
                session_raw = json.loads(result_row[idx])
            elif name == "data":
                result_raw[name] = json.loads(result_row[idx])
            else:
                result_raw[name] = result_row[idx]

        graph_raw = session_raw['execution_graph']

        result = parse_obj_as(Result, result_raw)
        session = parse_obj_as(GraphExecutionState, session_raw)

        m = build_core_metadata(graph_raw, result.node_id)
        print(m)

        # g = session.execution_graph.nx_graph()
        # ancestors = nx.dag.ancestors(g, result.node_id)

        # nodes = [session.execution_graph.get_node(result.node_id)]
        # for ancestor in ancestors:
        #     nodes.append(session.execution_graph.get_node(ancestor))

        # filtered_nodes = filter(lambda n: n.type in NODE_TYPE_ALLOWLIST, nodes)
        # print(list(map(lambda n: n.dict(), filtered_nodes)))
        # metadata = {}
        # for node in nodes:
        #     if (node.type in ['txt2img', 'img2img',])
        #     for field, value in node.dict().items():
        #         if field not in ['type', 'id']:
        #             if field not in metadata:
        #                 metadata[field] = value

        # print(ancestors)
        # print(nodes)
        # print(metadata)

        # for node in nodes:
        #     print(node.dict())

        return ResultWithSession(
            result=result,
            session=session,
        )

    def get(
        self, result_id: str, result_type: ResultType
    ) -> Union[ResultWithSession, None]:
        """Retrieves a result by ID and type."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT
                    results.id AS id,
                    results.result_type AS result_type,
                    results.node_id AS node_id,
                    results.session_id AS session_id,
                    results.data AS data,
                    graph_executions.item AS session
                FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE results.id = ? AND results.result_type = ?
                """,
                (result_id, result_type.value),
            )

            result_row = self._cursor.fetchone()

            if result_row is None:
                return None

            column_names = list(map(lambda x: x[0], self._cursor.description))
            result_parsed = self._parse_joined_result(result_row, column_names)
        finally:
            self._lock.release()

        if not result_parsed:
            return None

        return result_parsed

    def get_many(
        self,
        result_type: ResultType,
        page: int = 0,
        per_page: int = 10,
    ) -> PaginatedResults[ResultWithSession]:
        """Lists results of a given type."""
        try:
            self._lock.acquire()

            self._cursor.execute(
                """--sql
                SELECT
                    results.id AS id,
                    results.result_type AS result_type,
                    results.node_id AS node_id,
                    results.session_id AS session_id,
                    results.data AS data,
                    graph_executions.item AS session
                FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE results.result_type = ?
                LIMIT ? OFFSET ?;
                """,
                (result_type.value, per_page, page * per_page),
            )

            result_rows = self._cursor.fetchall()
            column_names = list(map(lambda c: c[0], self._cursor.description))

            result_parsed = []

            for result_row in result_rows:
                result_parsed.append(
                    self._parse_joined_result(result_row, column_names)
                )

            # Count only results of the requested type, to match the query above
            self._cursor.execute(
                """SELECT count(*) FROM results WHERE result_type = ?;""",
                (result_type.value,),
            )
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()

        pageCount = int(count / per_page) + 1

        return PaginatedResults[ResultWithSession](
            items=result_parsed,
            page=page,
            pages=pageCount,
            per_page=per_page,
            total=count,
        )

    def search(
        self,
        query: str,
        page: int = 0,
        per_page: int = 10,
    ) -> PaginatedResults[ResultWithSession]:
        """Finds results by query."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT results.data, graph_executions.item
                FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE item LIKE ?
                LIMIT ? OFFSET ?;
                """,
                (f"%{query}%", per_page, page * per_page),
            )

            result_rows = self._cursor.fetchall()

            items = list(
                map(
                    lambda r: ResultWithSession(
                        result=parse_raw_as(Result, r[0]),
                        session=parse_raw_as(GraphExecutionState, r[1]),
                    ),
                    result_rows,
                )
            )
            # The `item` column lives on graph_executions, so the count query
            # needs the same JOIN as the query above
            self._cursor.execute(
                """--sql
                SELECT count(*) FROM results
                JOIN graph_executions ON results.session_id = graph_executions.id
                WHERE item LIKE ?;
                """,
                (f"%{query}%",),
            )
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()

        pageCount = int(count / per_page) + 1

        return PaginatedResults[ResultWithSession](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
        """Updates the results table with the results from the session."""
        with self._conn as conn:
            for node_id, result in session.results.items():
                # We only process 'image_output' or 'latents_output' results
                if result.type not in ["image_output", "latents_output"]:
                    continue

                # The id depends on the result type
                if result.type == "image_output":
                    id = result.image.image_name
                    result_type = "image_output"
                else:
                    id = result.latents.latents_name
                    result_type = "latents_output"

                # Insert the result into the results table, ignoring it if it already exists
                conn.execute(
                    """--sql
                    INSERT OR IGNORE INTO results (id, result_type, node_id, session_id, created_at, data)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (
                        id,
                        result_type,
                        node_id,
                        session.id,
                        get_timestamp(),
                        result.json(),
                    ),
                )
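A minimal wiring sketch for the service (editorial; the database path is hypothetical, and it assumes the same file also holds the `graph_executions` table that `get`/`get_many` JOIN against):

    results_service = SqliteResultsService(filename="invokeai.db")

    # Typically invoked from a session-changed event handler:
    # results_service.handle_graph_execution_state_change(session)

    page = results_service.get_many(ResultType.image_output, page=0, per_page=10)
    for item in page.items:
        print(item.result.id, item.result.node_id)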
@@ -1,12 +1,15 @@
 import sqlite3
 from threading import Lock
 from typing import Generic, TypeVar, Union, get_args

 from pydantic import BaseModel, parse_raw_as

 from .item_storage import ItemStorageABC, PaginatedResults

-T = TypeVar('T', bound=BaseModel)
+T = TypeVar("T", bound=BaseModel)

-sqlite_memory = ':memory:'
+sqlite_memory = ":memory:"

 class SqliteItemStorage(ItemStorageABC, Generic[T]):
     _filename: str
@@ -16,15 +19,17 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     _id_field: str
     _lock: Lock

-    def __init__(self, filename: str, table_name: str, id_field: str = 'id'):
+    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
         super().__init__()

         self._filename = filename
         self._table_name = table_name
-        self._id_field = id_field # TODO: validate that T has this field
+        self._id_field = id_field  # TODO: validate that T has this field
         self._lock = Lock()

-        self._conn = sqlite3.connect(self._filename, check_same_thread=False) # TODO: figure out a better threading solution
+        self._conn = sqlite3.connect(
+            self._filename, check_same_thread=False
+        )  # TODO: figure out a better threading solution
         self._cursor = self._conn.cursor()

         self._create_table()
@@ -32,10 +37,14 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     def _create_table(self):
         try:
             self._lock.acquire()
-            self._cursor.execute(f'''CREATE TABLE IF NOT EXISTS {self._table_name} (
-                item TEXT,
-                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);''')
-            self._cursor.execute(f'''CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);''')
+            self._cursor.execute(
+                f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
+                item TEXT,
+                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
+            )
+            self._cursor.execute(
+                f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
+            )
         finally:
             self._lock.release()

@@ -46,7 +55,11 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     def set(self, item: T):
         try:
             self._lock.acquire()
-            self._cursor.execute(f'''INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);''', (item.json(),))
+            self._cursor.execute(
+                f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
+                (item.json(),),
+            )
             self._conn.commit()
         finally:
             self._lock.release()
         self._on_changed(item)
@@ -54,7 +67,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     def get(self, id: str) -> Union[T, None]:
         try:
             self._lock.acquire()
-            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE id = ?;''', (str(id),))
+            self._cursor.execute(
+                f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
+            )
             result = self._cursor.fetchone()
         finally:
             self._lock.release()
@@ -67,7 +82,10 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     def delete(self, id: str):
         try:
             self._lock.acquire()
-            self._cursor.execute(f'''DELETE FROM {self._table_name} WHERE id = ?;''', (str(id),))
+            self._cursor.execute(
+                f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
+            )
             self._conn.commit()
         finally:
             self._lock.release()
         self._on_deleted(id)
@@ -75,12 +93,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
     def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
         try:
             self._lock.acquire()
-            self._cursor.execute(f'''SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;''', (per_page, page * per_page))
+            self._cursor.execute(
+                f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
+                (per_page, page * per_page),
+            )
             result = self._cursor.fetchall()

             items = list(map(lambda r: self._parse_item(r[0]), result))

-            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name};''')
+            self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
             count = self._cursor.fetchone()[0]
         finally:
             self._lock.release()
@@ -88,22 +109,26 @@
         pageCount = int(count / per_page) + 1

         return PaginatedResults[T](
-            items = items,
-            page = page,
-            pages = pageCount,
-            per_page = per_page,
-            total = count
+            items=items, page=page, pages=pageCount, per_page=per_page, total=count
         )

-    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
+    def search(
+        self, query: str, page: int = 0, per_page: int = 10
+    ) -> PaginatedResults[T]:
         try:
             self._lock.acquire()
-            self._cursor.execute(f'''SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;''', (f'%{query}%', per_page, page * per_page))
+            self._cursor.execute(
+                f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
+                (f"%{query}%", per_page, page * per_page),
+            )
             result = self._cursor.fetchall()

             items = list(map(lambda r: self._parse_item(r[0]), result))

-            self._cursor.execute(f'''SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;''', (f'%{query}%',))
+            self._cursor.execute(
+                f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
+                (f"%{query}%",),
+            )
             count = self._cursor.fetchone()[0]
         finally:
             self._lock.release()
@@ -111,9 +136,5 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         pageCount = int(count / per_page) + 1

         return PaginatedResults[T](
-            items = items,
-            page = page,
-            pages = pageCount,
-            per_page = per_page,
-            total = count
+            items=items, page=page, pages=pageCount, per_page=per_page, total=count
         )
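A quick usage sketch of the storage class after this change (editorial; `Widget` is a hypothetical model). Items are stored as JSON, with `id` a generated column extracted from the configured field:

    from pydantic import BaseModel

    class Widget(BaseModel):
        name: str
        size: int

    storage = SqliteItemStorage[Widget](
        filename=sqlite_memory, table_name="widgets", id_field="name"
    )
    storage.set(Widget(name="w1", size=3))
    storage.get("w1")      # -> Widget(name='w1', size=3)
    storage.list().total   # -> 1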
30  invokeai/app/services/urls.py  Normal file
@@ -0,0 +1,30 @@
import os
from abc import ABC, abstractmethod

from invokeai.app.models.image import ImageType
from invokeai.app.util.thumbnails import get_thumbnail_name


class UrlServiceBase(ABC):
    """Responsible for building URLs for resources."""

    @abstractmethod
    def get_image_url(
        self, image_type: ImageType, image_name: str, thumbnail: bool = False
    ) -> str:
        """Gets the URL for an image or thumbnail."""
        pass


class LocalUrlService(UrlServiceBase):
    def __init__(self, base_url: str = "api/v1"):
        self._base_url = base_url

    def get_image_url(
        self, image_type: ImageType, image_name: str, thumbnail: bool = False
    ) -> str:
        image_basename = os.path.basename(image_name)
        if thumbnail:
            return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail"

        return f"{self._base_url}/files/images/{image_type.value}/{image_basename}"
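Example output (editorial; this assumes an ImageType member whose value is "results", which is not shown in this diff):

    urls = LocalUrlService()
    urls.get_image_url(ImageType.RESULT, "abc123.png")
    # -> 'api/v1/files/images/results/abc123.png'
    urls.get_image_url(ImageType.RESULT, "abc123.png", thumbnail=True)
    # -> 'api/v1/files/images/results/abc123.png/thumbnail'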
39  invokeai/app/services/util/create_enum_table.py  Normal file
@@ -0,0 +1,39 @@
from enum import Enum
import sqlite3
from typing import Type


def create_sql_values_string_from_string_enum(enum: Type[Enum]):
    """
    Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
    """

    delimiter = ", "
    values = [f"('{e.value}')" for e in enum]
    return delimiter.join(values)


def create_enum_table(
    enum: Type[Enum],
    table_name: str,
    primary_key_name: str,
    cursor: sqlite3.Cursor,
):
    """
    Creates and populates a table to be used as a functional enum.
    """

    values_string = create_sql_values_string_from_string_enum(enum)

    cursor.execute(
        f"""--sql
        CREATE TABLE IF NOT EXISTS {table_name} (
            {primary_key_name} TEXT PRIMARY KEY
        );
        """
    )
    cursor.execute(
        f"""--sql
        INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
        """
    )
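A usage sketch (editorial; the enum is hypothetical): materialize a string enum as a lookup table, e.g. the `resource_origins` table referenced by the `tensors` schema earlier in this diff:

    from enum import Enum
    import sqlite3

    class ResourceOrigin(str, Enum):
        RESULTS = "results"
        UPLOADS = "uploads"

    conn = sqlite3.connect(":memory:")
    create_enum_table(ResourceOrigin, "resource_origins", "origin_name", conn.cursor())
    # Executes: INSERT OR IGNORE INTO resource_origins (origin_name) VALUES ('results'), ('uploads');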
12  invokeai/app/util/enum.py  Normal file
@@ -0,0 +1,12 @@
from enum import EnumMeta


class MetaEnum(EnumMeta):
    """Metaclass to support `in` syntax value checking in String Enums"""

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True
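With this metaclass, membership checks test against enum values rather than members (an editorial example; the enum is hypothetical):

    from enum import Enum

    class TensorKind(str, Enum, metaclass=MetaEnum):
        IMAGE = "image"
        LATENTS = "latents"

    "latents" in TensorKind  # True - MetaEnum.__contains__ tries TensorKind("latents")
    "bogus" in TensorKind    # False - the ValueError is swallowed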
21  invokeai/app/util/misc.py  Normal file
@@ -0,0 +1,21 @@
import datetime
import numpy as np


def get_timestamp():
    return int(datetime.datetime.now(datetime.timezone.utc).timestamp())


def get_iso_timestamp() -> str:
    return datetime.datetime.utcnow().isoformat()


def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
    return datetime.datetime.fromisoformat(iso_timestamp)


SEED_MAX = np.iinfo(np.int32).max


def get_random_seed():
    return np.random.randint(0, SEED_MAX)
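A small sanity check of the seed helpers (editorial): SEED_MAX is the signed 32-bit maximum, and numpy's randint upper bound is exclusive, so seeds stay within 32-bit range:

    assert SEED_MAX == 2**31 - 1  # np.iinfo(np.int32).max
    seed = get_random_seed()
    assert 0 <= seed < SEED_MAX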
55  invokeai/app/util/step_callback.py  Normal file
@@ -0,0 +1,55 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState


def stable_diffusion_step_callback(
    context: InvocationContext,
    intermediate_state: PipelineIntermediateState,
    node: dict,
    source_node_id: str,
):
    if context.services.queue.is_canceled(context.graph_execution_state_id):
        raise CanceledException

    # Some schedulers report not only the noisy latents at the current timestep,
    # but also their estimate so far of what the de-noised latents will be. Use
    # that estimate if it is available.
    if intermediate_state.predicted_original is not None:
        sample = intermediate_state.predicted_original
    else:
        sample = intermediate_state.latents

    # TODO: This does not seem to be needed any more?
    # # txt2img provides a Tensor in the step_callback
    # # img2img provides a PipelineIntermediateState
    # if isinstance(sample, PipelineIntermediateState):
    #     # this was an img2img
    #     print('img2img')
    #     latents = sample.latents
    #     step = sample.step
    # else:
    #     print('txt2img')
    #     latents = sample
    #     step = intermediate_state.step

    # TODO: only output a preview image when requested
    image = Generator.sample_to_lowres_estimated_image(sample)

    (width, height) = image.size
    width *= 8
    height *= 8

    dataURL = image_to_dataURL(image, image_format="JPEG")

    context.services.events.emit_generator_progress(
        graph_execution_state_id=context.graph_execution_state_id,
        node=node,
        source_node_id=source_node_id,
        progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
        step=intermediate_state.step,
        total_steps=node["steps"],
    )
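The x8 width/height scaling above reflects Stable Diffusion's VAE downsampling factor: the low-res preview estimated from the latents is 1/8 the pixel resolution in each dimension. A quick check of the arithmetic (editorial):

    # Latents for a 512x512 image have spatial shape 64x64; the progress event
    # reports the preview at the full pixel resolution.
    preview_w, preview_h = 64, 64
    assert (preview_w * 8, preview_h * 8) == (512, 512)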
Some files were not shown because too many files have changed in this diff.