mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 16:57:58 -05:00
Compare commits
628 Commits
v3.0.1rc2
...
bugfix/mak
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86c11f9e27 | ||
|
|
832335998f | ||
|
|
1102c12084 | ||
|
|
b5cee7d20c | ||
|
|
89b82b3dc4 | ||
|
|
8923201fdf | ||
|
|
226409107b | ||
|
|
ae986bf873 | ||
|
|
daf75a1361 | ||
|
|
fe4b2d53ed | ||
|
|
c39f8b478b | ||
|
|
1f82d8013e | ||
|
|
e373bfca54 | ||
|
|
2ca8611723 | ||
|
|
b12cf315a8 | ||
|
|
975586bb40 | ||
|
|
a7ba142ad9 | ||
|
|
0d36bab6cc | ||
|
|
c2e7f62701 | ||
|
|
1f194e3688 | ||
|
|
f9b8b5cff2 | ||
|
|
f7c92e1eff | ||
|
|
70b8c3dfea | ||
|
|
43b30355e4 | ||
|
|
a93bd01353 | ||
|
|
bb1b8ceaa8 | ||
|
|
be8edaf3fd | ||
|
|
9cbaefaa81 | ||
|
|
cc7c6e5d41 | ||
|
|
f2ee8a3da8 | ||
|
|
e98d7a52d4 | ||
|
|
21e1c0a5f0 | ||
|
|
611e241ca7 | ||
|
|
6df4af2c79 | ||
|
|
0f8606914e | ||
|
|
5b1099193d | ||
|
|
230131646f | ||
|
|
8b1ec2685f | ||
|
|
60c2c877d7 | ||
|
|
315a056686 | ||
|
|
80b0c5eab4 | ||
|
|
08dc265e09 | ||
|
|
029a95550e | ||
|
|
ee6a26a97d | ||
|
|
a512fdc0f6 | ||
|
|
767a612746 | ||
|
|
0a71d6baa1 | ||
|
|
37be827e17 | ||
|
|
04a9894e77 | ||
|
|
f9958de6be | ||
|
|
ec10aca91e | ||
|
|
2b7dd3e236 | ||
|
|
fa884134d9 | ||
|
|
18006cab9a | ||
|
|
75ea716c13 | ||
|
|
d5f7027597 | ||
|
|
b1ad777f5a | ||
|
|
f65c8092cb | ||
|
|
94bfef3543 | ||
|
|
c48fd9c083 | ||
|
|
f49fc7fb55 | ||
|
|
a4b029d03c | ||
|
|
d6c9bf5b38 | ||
|
|
4f82273fc4 | ||
|
|
e54355f0f3 | ||
|
|
b2934be6ba | ||
|
|
eab67b6a01 | ||
|
|
02fa116690 | ||
|
|
5190a4c282 | ||
|
|
141d438517 | ||
|
|
549d2e0485 | ||
|
|
d3d8b71c67 | ||
|
|
6eaaa75a5d | ||
|
|
ba57ec5907 | ||
|
|
cd0e4bc1d7 | ||
|
|
9d3cd85bdd | ||
|
|
46a8eed33e | ||
|
|
9fee3f7b66 | ||
|
|
9217a217d4 | ||
|
|
b2700ffde4 | ||
|
|
511da59793 | ||
|
|
409e5d01ba | ||
|
|
58d5c61c79 | ||
|
|
3d8da67be3 | ||
|
|
957ee6d370 | ||
|
|
fecad2c014 | ||
|
|
550e6ef27a | ||
|
|
cc85c98bf3 | ||
|
|
75fb3f429f | ||
|
|
d63bb39475 | ||
|
|
096333ba3f | ||
|
|
0b2925709c | ||
|
|
7a8f14d595 | ||
|
|
59ba9fc0f6 | ||
|
|
6e0beb1ed4 | ||
|
|
94636ddb03 | ||
|
|
746e099f0d | ||
|
|
499e89d6f6 | ||
|
|
250d530260 | ||
|
|
90fa3eebb3 | ||
|
|
0aba105a8f | ||
|
|
9e2e82a752 | ||
|
|
561951ad98 | ||
|
|
3ff9961bda | ||
|
|
33779b6339 | ||
|
|
b35cdc05a5 | ||
|
|
9afb5d6ace | ||
|
|
50177b8ed9 | ||
|
|
c8864e475b | ||
|
|
fcf7f4ac77 | ||
|
|
29f1c6dc82 | ||
|
|
28208e6f49 | ||
|
|
c33acf951e | ||
|
|
500cd552bc | ||
|
|
55d27f71a3 | ||
|
|
746c7c59ff | ||
|
|
ad96c41156 | ||
|
|
27bd127fb0 | ||
|
|
f296e5c41e | ||
|
|
a67d8376c7 | ||
|
|
9f6221fe8c | ||
|
|
7587b54787 | ||
|
|
7254ffc3e7 | ||
|
|
6034fa12de | ||
|
|
ce3675fc14 | ||
|
|
8acd7eeca5 | ||
|
|
7293a6036a | ||
|
|
0b11f309ca | ||
|
|
6a8eb392b2 | ||
|
|
f343ab0302 | ||
|
|
824ca92760 | ||
|
|
d7d6298ec0 | ||
|
|
58a48bf197 | ||
|
|
5629d8fa37 | ||
|
|
1affb7f647 | ||
|
|
69a9dc7b36 | ||
|
|
f3ae52ff97 | ||
|
|
7479f9cc02 | ||
|
|
87ce4ab27c | ||
|
|
7c0023ad9e | ||
|
|
231e665675 | ||
|
|
80fd4c2176 | ||
|
|
3b6e425e17 | ||
|
|
50415450d8 | ||
|
|
06296896a9 | ||
|
|
a7399aca0c | ||
|
|
d1ea8b1e98 | ||
|
|
f851ad7ba0 | ||
|
|
591838a84b | ||
|
|
c0c2ab3dcf | ||
|
|
56023bc725 | ||
|
|
2ef6a8995b | ||
|
|
d0fee93aac | ||
|
|
1bfe9835cf | ||
|
|
8e7eae6cc7 | ||
|
|
f6522c8971 | ||
|
|
a969707e45 | ||
|
|
6c8e898f09 | ||
|
|
7bad9bcf53 | ||
|
|
d42b45116f | ||
|
|
d4812bbc8d | ||
|
|
3cd05cf6bf | ||
|
|
2564301aeb | ||
|
|
da0efeaa7f | ||
|
|
49cce1eec6 | ||
|
|
e9ec5ab85c | ||
|
|
17fed1c870 | ||
|
|
ade78b9591 | ||
|
|
c8fbaf54b6 | ||
|
|
f86d388786 | ||
|
|
cd2c688562 | ||
|
|
2d29ac6f0d | ||
|
|
2c2b731386 | ||
|
|
2f68a1a76c | ||
|
|
930e7bc754 | ||
|
|
7d4ace962a | ||
|
|
06842f8e0a | ||
|
|
c82da330db | ||
|
|
628df4ec98 | ||
|
|
16b956616f | ||
|
|
604cc17a3a | ||
|
|
37c9b85549 | ||
|
|
8b39b67ec7 | ||
|
|
a933977861 | ||
|
|
dfb41d8461 | ||
|
|
e98f7eda2e | ||
|
|
b4a74f6523 | ||
|
|
f7aec3b934 | ||
|
|
4d5169e16d | ||
|
|
a7e44678fb | ||
|
|
da0184a786 | ||
|
|
f56f19710d | ||
|
|
96b7248051 | ||
|
|
e77400ab62 | ||
|
|
13347f6aec | ||
|
|
a9bf387e5e | ||
|
|
8258c87a9f | ||
|
|
1b1b399fd0 | ||
|
|
a8d3e078c0 | ||
|
|
6ed7ba57dd | ||
|
|
2b3b77a276 | ||
|
|
8b8ec68b30 | ||
|
|
e20af5aef0 | ||
|
|
57e8ec9488 | ||
|
|
734a9e4271 | ||
|
|
fe924daee3 | ||
|
|
750f09fbed | ||
|
|
4df581811e | ||
|
|
eb70bc2ae4 | ||
|
|
5f29526a8e | ||
|
|
492bfe002a | ||
|
|
809705c30d | ||
|
|
f0918edf98 | ||
|
|
a846d82fa1 | ||
|
|
22f7cf0638 | ||
|
|
25c669b1d6 | ||
|
|
4367061b19 | ||
|
|
0fd13d3604 | ||
|
|
72a3e776b2 | ||
|
|
af044007d5 | ||
|
|
1db2c93f75 | ||
|
|
f272a44feb | ||
|
|
2539e26c18 | ||
|
|
b0738b7f70 | ||
|
|
8469d3e95a | ||
|
|
ae17d01e1d | ||
|
|
f3d3316558 | ||
|
|
5a6cefb0ea | ||
|
|
1a6f5f0860 | ||
|
|
5bfd6cb66f | ||
|
|
59caff7ff0 | ||
|
|
6487e7d906 | ||
|
|
77033eabd3 | ||
|
|
b80abdd101 | ||
|
|
006d782cc8 | ||
|
|
d09dfc3e9b | ||
|
|
66f524cae7 | ||
|
|
9ba50130a1 | ||
|
|
d4cf2d2666 | ||
|
|
9aaf67c5b4 | ||
|
|
b8b589c150 | ||
|
|
d93900a8de | ||
|
|
7f4c387080 | ||
|
|
80876bbbd1 | ||
|
|
7a4ff4c089 | ||
|
|
44bf308192 | ||
|
|
12e51c84ae | ||
|
|
b2eb83deff | ||
|
|
0ccc3b509e | ||
|
|
4043a4c21c | ||
|
|
c8ceb96091 | ||
|
|
83f75750a9 | ||
|
|
dc96a3e79d | ||
|
|
c076f1397e | ||
|
|
2568aafc0b | ||
|
|
65ed224bfc | ||
|
|
b6e369c745 | ||
|
|
ecabfc252b | ||
|
|
da96a41103 | ||
|
|
d162b78767 | ||
|
|
eb6c317f04 | ||
|
|
6d7223238f | ||
|
|
8607d124c5 | ||
|
|
23497bf759 | ||
|
|
b10cf20eb1 | ||
|
|
3d93851dba | ||
|
|
9bacd77a79 | ||
|
|
1b158f62c4 | ||
|
|
6ad565d84c | ||
|
|
04229082d6 | ||
|
|
03c27412f7 | ||
|
|
f0613bb0ef | ||
|
|
0e9f92b868 | ||
|
|
7d0cc6ec3f | ||
|
|
2f8b928486 | ||
|
|
0d3c27f46c | ||
|
|
cff91f06d3 | ||
|
|
1d5d187ba1 | ||
|
|
1ac14a1e43 | ||
|
|
cfc3a20565 | ||
|
|
05ae4e283c | ||
|
|
f06fee4581 | ||
|
|
9091e19de8 | ||
|
|
0a0b7141af | ||
|
|
1deca89fde | ||
|
|
446fb4a438 | ||
|
|
ab5d938a1d | ||
|
|
9942af756a | ||
|
|
06742faca7 | ||
|
|
d2bddf7f91 | ||
|
|
91ebf9f76e | ||
|
|
bf94412d14 | ||
|
|
e080fd1e08 | ||
|
|
eeef1e08f8 | ||
|
|
b3b94b5a8d | ||
|
|
5c9787c145 | ||
|
|
cf72eba15c | ||
|
|
a6f9396a30 | ||
|
|
118d5b387b | ||
|
|
02d2cc758d | ||
|
|
db545f8801 | ||
|
|
b0d72b15b3 | ||
|
|
4e0949fa55 | ||
|
|
f028342f5b | ||
|
|
7021467048 | ||
|
|
26ef5249b1 | ||
|
|
87424be95d | ||
|
|
366952f810 | ||
|
|
450e95de59 | ||
|
|
0ba8a0ea6c | ||
|
|
f4981f26d5 | ||
|
|
6bc21984c6 | ||
|
|
43d6312587 | ||
|
|
0d125bf3e4 | ||
|
|
921ccad04d | ||
|
|
05c9207e7b | ||
|
|
3fc789a7ee | ||
|
|
008362918e | ||
|
|
8fc75a71ee | ||
|
|
82d259f43b | ||
|
|
ec48779080 | ||
|
|
bc20fe4cb5 | ||
|
|
5de42be4a6 | ||
|
|
818c55cd53 | ||
|
|
0db1e97119 | ||
|
|
29ac252501 | ||
|
|
880727436c | ||
|
|
77c5c18542 | ||
|
|
ed76250dba | ||
|
|
4d22cafdad | ||
|
|
1f9e984b0d | ||
|
|
8a4e5f73aa | ||
|
|
4599575e65 | ||
|
|
242d860a47 | ||
|
|
0c1a7e72d4 | ||
|
|
11a44b944d | ||
|
|
fd7b842419 | ||
|
|
5998509888 | ||
|
|
403a6e88f2 | ||
|
|
c9d452b9d4 | ||
|
|
dcc274a2b9 | ||
|
|
f404669831 | ||
|
|
ce687b28ef | ||
|
|
7292d89108 | ||
|
|
41d6a38690 | ||
|
|
fb8f218901 | ||
|
|
437f45a97f | ||
|
|
13ef33ed64 | ||
|
|
86d8b46fca | ||
|
|
e86925d424 | ||
|
|
df53b62048 | ||
|
|
55d3f04476 | ||
|
|
72ebe2ce68 | ||
|
|
7cd8b2f207 | ||
|
|
52437205bb | ||
|
|
ceebb501a4 | ||
|
|
cbe874b964 | ||
|
|
e2e5918ee2 | ||
|
|
1b131e328a | ||
|
|
81654daed7 | ||
|
|
746afcd235 | ||
|
|
ae0f4efcca | ||
|
|
23647336ce | ||
|
|
4ca54dd5fa | ||
|
|
d3a3067164 | ||
|
|
aeac557c41 | ||
|
|
af4fd328a6 | ||
|
|
c40c7424b6 | ||
|
|
a6b907150b | ||
|
|
bacdf985f1 | ||
|
|
e3519052ae | ||
|
|
b0e84c6497 | ||
|
|
f784e8412c | ||
|
|
1bafbafdd3 | ||
|
|
f5ac73b091 | ||
|
|
eb642653cb | ||
|
|
2c07f54b6e | ||
|
|
0691e0a12a | ||
|
|
79afcbd07e | ||
|
|
f4ead5e07f | ||
|
|
6d24ca7f52 | ||
|
|
2164da8592 | ||
|
|
adfd1e52f4 | ||
|
|
0e48c98330 | ||
|
|
4121c261a0 | ||
|
|
99823d5039 | ||
|
|
0abceb0e7b | ||
|
|
83d3f2347e | ||
|
|
73e25d8dbe | ||
|
|
50e00feceb | ||
|
|
03594c949a | ||
|
|
adb85036e6 | ||
|
|
7d7a9273ed | ||
|
|
f17ad227cf | ||
|
|
f91d01eb38 | ||
|
|
adfcb610b6 | ||
|
|
cafcd16657 | ||
|
|
2537ff0280 | ||
|
|
0f5f08e494 | ||
|
|
e20c4dc1e8 | ||
|
|
6dc4ddef1b | ||
|
|
26af5ec341 | ||
|
|
10b182f316 | ||
|
|
ac84a9f915 | ||
|
|
844578ab88 | ||
|
|
ff1c40747e | ||
|
|
dbfd1bcb5e | ||
|
|
444390617f | ||
|
|
6cb40d9d7b | ||
|
|
ca895a9cd0 | ||
|
|
7d27c7b1a4 | ||
|
|
6c82229910 | ||
|
|
43b1eb8e84 | ||
|
|
b10b07220e | ||
|
|
c2eb50d1cd | ||
|
|
73f3b7f84b | ||
|
|
bb18251fad | ||
|
|
348bee8981 | ||
|
|
078b33bda2 | ||
|
|
e82eb0b9fc | ||
|
|
ad976e5198 | ||
|
|
0e28961e69 | ||
|
|
6ce059f063 | ||
|
|
1de783b1ce | ||
|
|
3f9105be50 | ||
|
|
781322a647 | ||
|
|
9a1cfadd8b | ||
|
|
2a2d988928 | ||
|
|
ccceb32a85 | ||
|
|
72c519c6ad | ||
|
|
af12f67948 | ||
|
|
60f5606c2d | ||
|
|
24b19166dd | ||
|
|
0396bce4f9 | ||
|
|
71768f5988 | ||
|
|
0fb7328022 | ||
|
|
99daa97978 | ||
|
|
21617e60e1 | ||
|
|
982a568349 | ||
|
|
d79d5a4ff7 | ||
|
|
9968ff2893 | ||
|
|
35dd58e273 | ||
|
|
6d82a1019a | ||
|
|
6ed1bf7084 | ||
|
|
974175be45 | ||
|
|
86b8b69e88 | ||
|
|
bc9a5038fd | ||
|
|
bee678fdd1 | ||
|
|
c5caf1e8fe | ||
|
|
b163ae6a4d | ||
|
|
dca685ac25 | ||
|
|
72708eb53c | ||
|
|
aae1670080 | ||
|
|
e70bedba7d | ||
|
|
1e776d2523 | ||
|
|
8e06e6abbc | ||
|
|
8a0e1b6cfc | ||
|
|
2d9bc79ca4 | ||
|
|
6886eb094d | ||
|
|
6ca0c38ee3 | ||
|
|
d633eb1612 | ||
|
|
1bbf2f269d | ||
|
|
ac22652686 | ||
|
|
77cfec5cc8 | ||
|
|
3e4420c1ae | ||
|
|
f8181ab1b3 | ||
|
|
3dfeead9b8 | ||
|
|
d3f6c7f983 | ||
|
|
390ce9f249 | ||
|
|
3da0be7eb9 | ||
|
|
8935ae0ea3 | ||
|
|
31e5f4bb0e | ||
|
|
2164674b01 | ||
|
|
8f2a646286 | ||
|
|
5ff4dd26bb | ||
|
|
e342ca872f | ||
|
|
a2aa66f43a | ||
|
|
da751da3dd | ||
|
|
2b7b3dd4ba | ||
|
|
dc1148106d | ||
|
|
062a369044 | ||
|
|
e4a2f56ad1 | ||
|
|
1df30f7260 | ||
|
|
514722d67a | ||
|
|
5dbde2116f | ||
|
|
14c4650801 | ||
|
|
f155b03eee | ||
|
|
ddaf753f7b | ||
|
|
e6d14c708c | ||
|
|
7f81a95b20 | ||
|
|
6a49eec606 | ||
|
|
fd67b18c9a | ||
|
|
9affdbbaad | ||
|
|
8d300bddd0 | ||
|
|
aa2c94be9e | ||
|
|
4c79350300 | ||
|
|
10e1d623c3 | ||
|
|
aa1f827271 | ||
|
|
fb113b9077 | ||
|
|
bb9460d278 | ||
|
|
6edeb4e072 | ||
|
|
2bb4e6d5aa | ||
|
|
53028feb83 | ||
|
|
d983dd371c | ||
|
|
17ee17a789 | ||
|
|
6b3ec29480 | ||
|
|
4a30773d09 | ||
|
|
006075483d | ||
|
|
1ea9ba84f5 | ||
|
|
64bd11541a | ||
|
|
52bd29d484 | ||
|
|
41b13e83a5 | ||
|
|
0d8f9cbe55 | ||
|
|
fd75a1dd10 | ||
|
|
bfdc8c80f3 | ||
|
|
3bb81bedbd | ||
|
|
e191f6d4b2 | ||
|
|
00988e4972 | ||
|
|
7d458eb1ac | ||
|
|
b8b46aec09 | ||
|
|
4d2b87ea01 | ||
|
|
8023a23cec | ||
|
|
e4c0102b3c | ||
|
|
16d044336f | ||
|
|
c4a2808a4b | ||
|
|
59716938bf | ||
|
|
611f31c057 | ||
|
|
b60adc31d0 | ||
|
|
a98ed3a5ba | ||
|
|
f057d5c85b | ||
|
|
918a0dedc0 | ||
|
|
218b6d0546 | ||
|
|
2183dba5c5 | ||
|
|
a491e326c5 | ||
|
|
f7bb4c3f05 | ||
|
|
57271ad125 | ||
|
|
33245b37ad | ||
|
|
81d8fb8762 | ||
|
|
fc9dacd082 | ||
|
|
8b4af69d87 | ||
|
|
989d3d7f3c | ||
|
|
d2a46b4308 | ||
|
|
eb1ba8d74b | ||
|
|
4ebde013ea | ||
|
|
024f92f9a9 | ||
|
|
562c937a14 | ||
|
|
5300e353d8 | ||
|
|
d78c97f8a8 | ||
|
|
52f61698e9 | ||
|
|
6f54fe9003 | ||
|
|
895917c3ab | ||
|
|
be00a837cc | ||
|
|
4d732e06de | ||
|
|
3ff8c87c09 | ||
|
|
f26a423e95 | ||
|
|
861c0fe76b | ||
|
|
c16da75ac7 | ||
|
|
36455f6cac | ||
|
|
2d0f932737 | ||
|
|
2aefa921fe | ||
|
|
8c449c4756 | ||
|
|
fc4e104c61 | ||
|
|
de73e4f5b9 | ||
|
|
0689e36390 | ||
|
|
78750042f5 | ||
|
|
13e7614508 | ||
|
|
4e1786d9ae | ||
|
|
585520d8d2 | ||
|
|
98b2734240 | ||
|
|
7b428b5240 | ||
|
|
ce08aa350c | ||
|
|
ba1a934297 | ||
|
|
4e90376d11 | ||
|
|
f73b45bcb5 | ||
|
|
23f4a4ea1a | ||
|
|
6aab8f16ce | ||
|
|
8f61413865 | ||
|
|
43b6a077fb | ||
|
|
e8299d0abb | ||
|
|
a28ab654ef | ||
|
|
8699fd7050 | ||
|
|
9e65470ada | ||
|
|
f4e52fafac | ||
|
|
ee7b36cea5 | ||
|
|
487455ef2e | ||
|
|
e201ad2f51 | ||
|
|
869f418b03 | ||
|
|
35d5ef9118 | ||
|
|
bcce70fca6 | ||
|
|
932112b640 | ||
|
|
91112167b1 | ||
|
|
bd7b59910d | ||
|
|
524888bf3b | ||
|
|
0327eae509 | ||
|
|
bb85608890 | ||
|
|
6c7668aaca | ||
|
|
7759b3f75a | ||
|
|
4d337f6abc | ||
|
|
92c86fd0b8 | ||
|
|
46dc751139 | ||
|
|
4cefe37723 | ||
|
|
82b73c50a0 | ||
|
|
7df7a95299 | ||
|
|
85b4b359c2 | ||
|
|
cfe81b5e00 | ||
|
|
b0c4451324 | ||
|
|
d4931522d4 | ||
|
|
17e2a35228 | ||
|
|
91016d8b29 | ||
|
|
9fda21cf40 | ||
|
|
809ec7163e | ||
|
|
7c9a939b47 | ||
|
|
9634c96020 | ||
|
|
e0c105f413 | ||
|
|
f0bf32c476 | ||
|
|
28373dbb98 | ||
|
|
4133d77772 | ||
|
|
61c426f502 | ||
|
|
bf0577c882 | ||
|
|
24673fd859 | ||
|
|
dc669d1447 | ||
|
|
ce4110b9f4 | ||
|
|
0f3b7d2b3d | ||
|
|
16dc78f6c6 | ||
|
|
7a66856785 | ||
|
|
c8dfa49d86 | ||
|
|
76dd749b1e | ||
|
|
67d05d2066 |
@@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path):
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('image1_path')
|
||||
parser.add_argument('image2_path')
|
||||
parser.add_argument("image1_path")
|
||||
parser.add_argument("image2_path")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
|
||||
print(mean_L1)
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
b3dccfaeb636599c02effc377cdd8a87d658256c
|
||||
218b6d0546b990fc449c876fb99f44b50c4daa35
|
||||
|
||||
4
.github/workflows/lint-frontend.yml
vendored
4
.github/workflows/lint-frontend.yml
vendored
@@ -2,8 +2,6 @@ name: Lint frontend
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'invokeai/frontend/web/**'
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
@@ -11,8 +9,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
paths:
|
||||
- 'invokeai/frontend/web/**'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
27
.github/workflows/style-checks.yml
vendored
Normal file
27
.github/workflows/style-checks.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
name: style checks
|
||||
# just formatting for now
|
||||
# TODO: add isort and flake8 later
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: main
|
||||
|
||||
jobs:
|
||||
black:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies with pip
|
||||
run: |
|
||||
pip install black
|
||||
|
||||
# - run: isort --check-only .
|
||||
- run: black --check .
|
||||
# - run: flake8
|
||||
50
.github/workflows/test-invoke-pip-skip.yml
vendored
50
.github/workflows/test-invoke-pip-skip.yml
vendored
@@ -1,50 +0,0 @@
|
||||
name: Test invoke.py pip
|
||||
|
||||
# This is a dummy stand-in for the actual tests
|
||||
# we don't need to run python tests on non-Python changes
|
||||
# But PRs require passing tests to be mergeable
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**'
|
||||
- '!pyproject.toml'
|
||||
- '!invokeai/**'
|
||||
- '!tests/**'
|
||||
- 'invokeai/frontend/web/**'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
if: github.event.pull_request.draft == false
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
pytorch:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- pytorch: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
- pytorch: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
- pytorch: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
- pytorch: macos-default
|
||||
os: macOS-12
|
||||
- pytorch: windows-cpu
|
||||
os: windows-2022
|
||||
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: skip
|
||||
run: echo "no build required"
|
||||
24
.github/workflows/test-invoke-pip.yml
vendored
24
.github/workflows/test-invoke-pip.yml
vendored
@@ -3,16 +3,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
paths:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- 'tests/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
@@ -65,10 +56,23 @@ jobs:
|
||||
id: checkout-sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Check for changed python files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v37
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -76,6 +80,7 @@ jobs:
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install invokeai
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
@@ -83,6 +88,7 @@ jobs:
|
||||
--editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -38,7 +38,6 @@ develop-eggs/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
|
||||
10
.pre-commit-config.yaml
Normal file
10
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# See https://pre-commit.com/ for usage and config
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: black
|
||||
name: black
|
||||
stages: [commit]
|
||||
language: system
|
||||
entry: black
|
||||
types: [python]
|
||||
40
README.md
40
README.md
@@ -123,7 +123,7 @@ and go to http://localhost:9090.
|
||||
|
||||
### Command-Line Installation (for developers and users familiar with Terminals)
|
||||
|
||||
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
|
||||
You must have Python 3.9 through 3.11 installed on your machine. Earlier or
|
||||
later versions are not supported.
|
||||
Node.js also needs to be installed along with yarn (can be installed with
|
||||
the command `npm install -g yarn` if needed)
|
||||
@@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
|
||||
_For Windows/Linux with an NVIDIA GPU:_
|
||||
|
||||
```terminal
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
_For Linux with an AMD GPU:_
|
||||
@@ -184,8 +184,9 @@ the command `npm install -g yarn` if needed)
|
||||
6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
|
||||
|
||||
```terminal
|
||||
invokeai-configure
|
||||
invokeai-configure --root .
|
||||
```
|
||||
Don't miss the dot at the end!
|
||||
|
||||
7. Launch the web server (do it every time you run InvokeAI):
|
||||
|
||||
@@ -193,15 +194,9 @@ the command `npm install -g yarn` if needed)
|
||||
invokeai-web
|
||||
```
|
||||
|
||||
8. Build Node.js assets
|
||||
8. Point your browser to http://localhost:9090 to bring up the web interface.
|
||||
|
||||
```terminal
|
||||
cd invokeai/frontend/web/
|
||||
yarn vite build
|
||||
```
|
||||
|
||||
9. Point your browser to http://localhost:9090 to bring up the web interface.
|
||||
10. Type `banana sushi` in the box on the top left and click `Invoke`.
|
||||
9. Type `banana sushi` in the box on the top left and click `Invoke`.
|
||||
|
||||
Be sure to activate the virtual environment each time before re-launching InvokeAI,
|
||||
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
|
||||
@@ -311,13 +306,30 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
|
||||
You may now launch the WebUI in the usual way, by selecting option [1]
|
||||
from the launcher script
|
||||
|
||||
#### Migration Caveats
|
||||
#### Migrating Images
|
||||
|
||||
The migration script will migrate your invokeai settings and models,
|
||||
including textual inversion models, LoRAs and merges that you may have
|
||||
installed previously. However it does **not** migrate the generated
|
||||
images stored in your 2.3-format outputs directory. You will need to
|
||||
manually import selected images into the 3.0 gallery via drag-and-drop.
|
||||
images stored in your 2.3-format outputs directory. To do this, you
|
||||
need to run an additional step:
|
||||
|
||||
1. From a working InvokeAI 3.0 root directory, start the launcher and
|
||||
enter menu option [8] to open the "developer's console".
|
||||
|
||||
2. At the developer's console command line, type the command:
|
||||
|
||||
```bash
|
||||
invokeai-import-images
|
||||
```
|
||||
|
||||
3. This will lead you through the process of confirming the desired
|
||||
source and destination for the imported images. The images will
|
||||
appear in the gallery board of your choice, and contain the
|
||||
original prompt, model name, and other parameters used to generate
|
||||
the image.
|
||||
|
||||
(Many kudos to **techjedi** for contributing this script.)
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ configure() {
|
||||
echo "To reconfigure InvokeAI, delete the above file."
|
||||
echo "======================================================================"
|
||||
else
|
||||
mkdir -p ${INVOKEAI_ROOT}
|
||||
chown --recursive ${USER} ${INVOKEAI_ROOT}
|
||||
mkdir -p "${INVOKEAI_ROOT}"
|
||||
chown --recursive ${USER} "${INVOKEAI_ROOT}"
|
||||
gosu ${USER} invokeai-configure --yes --default_only
|
||||
fi
|
||||
}
|
||||
@@ -50,16 +50,16 @@ fi
|
||||
if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
|
||||
apt-get update
|
||||
apt-get install -y openssh-server
|
||||
pushd $HOME
|
||||
pushd "$HOME"
|
||||
mkdir -p .ssh
|
||||
echo ${PUBLIC_KEY} > .ssh/authorized_keys
|
||||
echo "${PUBLIC_KEY}" > .ssh/authorized_keys
|
||||
chmod -R 700 .ssh
|
||||
popd
|
||||
service ssh start
|
||||
fi
|
||||
|
||||
|
||||
cd ${INVOKEAI_ROOT}
|
||||
cd "${INVOKEAI_ROOT}"
|
||||
|
||||
# Run the CMD as the Container User (not root).
|
||||
exec gosu ${USER} "$@"
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 310 KiB After Width: | Height: | Size: 297 KiB |
BIN
docs/assets/troubleshooting/broken-dependency.png
Normal file
BIN
docs/assets/troubleshooting/broken-dependency.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 131 KiB |
@@ -16,7 +16,7 @@ If you don't feel ready to make a code contribution yet, no problem! You can als
|
||||
There are two paths to making a development contribution:
|
||||
|
||||
1. Choosing an open issue to address. Open issues can be found in the [Issues](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen) section of the InvokeAI repository. These are tagged by the issue type (bug, enhancement, etc.) along with the “good first issues” tag denoting if they are suitable for first time contributors.
|
||||
1. Additional items can be found on our roadmap <******************************link to roadmap>******************************. The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
|
||||
1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
|
||||
2. Opening a new issue or feature to add. **Please make sure you have searched through existing issues before creating new ones.**
|
||||
|
||||
*Regardless of what you choose, please post in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord before you start development in order to confirm that the issue or feature is aligned with the current direction of the project. We value our contributors time and effort and want to ensure that no one’s time is being misspent.*
|
||||
|
||||
@@ -4,35 +4,13 @@ title: Postprocessing
|
||||
|
||||
# :material-image-edit: Postprocessing
|
||||
|
||||
## Intro
|
||||
|
||||
This extension provides the ability to restore faces and upscale images.
|
||||
This sections details the ability to improve faces and upscale images.
|
||||
|
||||
## Face Fixing
|
||||
|
||||
The default face restoration module is GFPGAN. The default upscale is
|
||||
Real-ESRGAN. For an alternative face restoration module, see
|
||||
[CodeFormer Support](#codeformer-support) below.
|
||||
As of InvokeAI 3.0, the easiest way to improve faces created during image generation is through the Inpainting functionality of the Unified Canvas. Simply add the image containing the faces that you would like to improve to the canvas, mask the face to be improved and run the invocation. For best results, make sure to use an inpainting specific model; these are usually identified by the "-inpainting" term in the model name.
|
||||
|
||||
As of version 1.14, environment.yaml will install the Real-ESRGAN package into
|
||||
the standard install location for python packages, and will put GFPGAN into a
|
||||
subdirectory of "src" in the InvokeAI directory. Upscaling with Real-ESRGAN
|
||||
should "just work" without further intervention. Simply indicate the desired scale on
|
||||
the popup in the Web GUI.
|
||||
|
||||
**GFPGAN** requires a series of downloadable model files to work. These are
|
||||
loaded when you run `invokeai-configure`. If GFPAN is failing with an
|
||||
error, please run the following from the InvokeAI directory:
|
||||
|
||||
```bash
|
||||
invokeai-configure
|
||||
```
|
||||
|
||||
If you do not run this script in advance, the GFPGAN module will attempt to
|
||||
download the models files the first time you try to perform facial
|
||||
reconstruction.
|
||||
|
||||
### Upscaling
|
||||
## Upscaling
|
||||
|
||||
Open the upscaling dialog by clicking on the "expand" icon located
|
||||
above the image display area in the Web UI:
|
||||
@@ -41,82 +19,23 @@ above the image display area in the Web UI:
|
||||

|
||||
</figure>
|
||||
|
||||
There are three different upscaling parameters that you can
|
||||
adjust. The first is the scale itself, either 2x or 4x.
|
||||
The default upscaling option is Real-ESRGAN x2 Plus, which will scale your image by a factor of two. This means upscaling a 512x512 image will result in a new 1024x1024 image.
|
||||
|
||||
The second is the "Denoising Strength." Higher values will smooth out
|
||||
the image and remove digital chatter, but may lose fine detail at
|
||||
higher values.
|
||||
Other options are the x4 upscalers, which will scale your image by a factor of 4.
|
||||
|
||||
Third, "Upscale Strength" allows you to adjust how the You can set the
|
||||
scaling stength between `0` and `1.0` to control the intensity of the
|
||||
scaling. AI upscalers generally tend to smooth out texture details. If
|
||||
you wish to retain some of those for natural looking results, we
|
||||
recommend using values between `0.5 to 0.8`.
|
||||
|
||||
[This figure](../assets/features/upscaling-montage.png) illustrates
|
||||
the effects of denoising and strength. The original image was 512x512,
|
||||
4x scaled to 2048x2048. The "original" version on the upper left was
|
||||
scaled using simple pixel averaging. The remainder use the ESRGAN
|
||||
upscaling algorithm at different levels of denoising and strength.
|
||||
|
||||
<figure markdown>
|
||||
{ width=720 }
|
||||
</figure>
|
||||
|
||||
Both denoising and strength default to 0.75.
|
||||
|
||||
### Face Restoration
|
||||
|
||||
InvokeAI offers alternative two face restoration algorithms,
|
||||
[GFPGAN](https://github.com/TencentARC/GFPGAN) and
|
||||
[CodeFormer](https://huggingface.co/spaces/sczhou/CodeFormer). These
|
||||
algorithms improve the appearance of faces, particularly eyes and
|
||||
mouths. Issues with faces are less common with the latest set of
|
||||
Stable Diffusion models than with the original 1.4 release, but the
|
||||
restoration algorithms can still make a noticeable improvement in
|
||||
certain cases. You can also apply restoration to old photographs you
|
||||
upload.
|
||||
|
||||
To access face restoration, click the "smiley face" icon in the
|
||||
toolbar above the InvokeAI image panel. You will be presented with a
|
||||
dialog that offers a choice between the two algorithm and sliders that
|
||||
allow you to adjust their parameters. Alternatively, you may open the
|
||||
left-hand accordion panel labeled "Face Restoration" and have the
|
||||
restoration algorithm of your choice applied to generated images
|
||||
automatically.
|
||||
|
||||
|
||||
Like upscaling, there are a number of parameters that adjust the face
|
||||
restoration output. GFPGAN has a single parameter, `strength`, which
|
||||
controls how much the algorithm is allowed to adjust the
|
||||
image. CodeFormer has two parameters, `strength`, and `fidelity`,
|
||||
which together control the quality of the output image as described in
|
||||
the [CodeFormer project
|
||||
page](https://shangchenzhou.com/projects/CodeFormer/). Default values
|
||||
are 0.75 for both parameters, which achieves a reasonable balance
|
||||
between changing the image too much and not enough.
|
||||
|
||||
[This figure](../assets/features/restoration-montage.png) illustrates
|
||||
the effects of adjusting GFPGAN and CodeFormer parameters.
|
||||
|
||||
<figure markdown>
|
||||
{ width=720 }
|
||||
</figure>
|
||||
|
||||
!!! note
|
||||
|
||||
GFPGAN and Real-ESRGAN are both memory intensive. In order to avoid crashes and memory overloads
|
||||
Real-ESRGAN is memory intensive. In order to avoid crashes and memory overloads
|
||||
during the Stable Diffusion process, these effects are applied after Stable Diffusion has completed
|
||||
its work.
|
||||
|
||||
In single image generations, you will see the output right away but when you are using multiple
|
||||
iterations, the images will first be generated and then upscaled and face restored after that
|
||||
iterations, the images will first be generated and then upscaled after that
|
||||
process is complete. While the image generation is taking place, you will still be able to preview
|
||||
the base images.
|
||||
|
||||
## How to disable
|
||||
|
||||
If, for some reason, you do not wish to load the GFPGAN and/or ESRGAN libraries,
|
||||
you can disable them on the invoke.py command line with the `--no_restore` and
|
||||
`--no_esrgan` options, respectively.
|
||||
If, for some reason, you do not wish to load the ESRGAN libraries,
|
||||
you can disable them on the invoke.py command line with the `--no_esrgan` options.
|
||||
|
||||
@@ -4,6 +4,9 @@ title: Overview
|
||||
|
||||
Here you can find the documentation for InvokeAI's various features.
|
||||
|
||||
## The [Getting Started Guide](../help/gettingStartedWithAI)
|
||||
A getting started guide for those new to AI image generation.
|
||||
|
||||
## The Basics
|
||||
### * The [Web User Interface](WEB.md)
|
||||
Guide to the Web interface. Also see the [WebUI Hotkeys Reference Guide](WEBUIHOTKEYS.md)
|
||||
@@ -46,7 +49,7 @@ Personalize models by adding your own style or subjects.
|
||||
|
||||
## Other Features
|
||||
|
||||
### * [The NSFW Checker](NSFW.md)
|
||||
### * [The NSFW Checker](WATERMARK+NSFW.md)
|
||||
Prevent InvokeAI from displaying unwanted racy images.
|
||||
|
||||
### * [Controlling Logging](LOGGING.md)
|
||||
|
||||
95
docs/help/gettingStartedWithAI.md
Normal file
95
docs/help/gettingStartedWithAI.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# Getting Started with AI Image Generation
|
||||
|
||||
New to image generation with AI? You’re in the right place!
|
||||
|
||||
This is a high level walkthrough of some of the concepts and terms you’ll see as you start using InvokeAI. Please note, this is not an exhaustive guide and may be out of date due to the rapidly changing nature of the space.
|
||||
|
||||
## Using InvokeAI
|
||||
|
||||
### **Prompt Crafting**
|
||||
|
||||
- Prompts are the basis of using InvokeAI, providing the models directions on what to generate. As a general rule of thumb, the more detailed your prompt is, the better your result will be.
|
||||
|
||||
*To get started, here’s an easy template to use for structuring your prompts:*
|
||||
|
||||
- Subject, Style, Quality, Aesthetic
|
||||
- **Subject:** What your image will be about. E.g. “a futuristic city with trains”, “penguins floating on icebergs”, “friends sharing beers”
|
||||
- **Style:** The style or medium in which your image will be in. E.g. “photograph”, “pencil sketch”, “oil paints”, or “pop art”, “cubism”, “abstract”
|
||||
- **Quality:** A particular aspect or trait that you would like to see emphasized in your image. E.g. "award-winning", "featured in {relevant set of high quality works}", "professionally acclaimed". Many people often use "masterpiece".
|
||||
- **Aesthetics:** The visual impact and design of the artwork. This can be colors, mood, lighting, setting, etc.
|
||||
- There are two prompt boxes: *Positive Prompt* & *Negative Prompt*.
|
||||
- A **Positive** Prompt includes words you want the model to reference when creating an image.
|
||||
- Negative Prompt is for anything you want the model to eliminate when creating an image. It doesn’t always interpret things exactly the way you would, but helps control the generation process. Always try to include a few terms - you can typically use lower quality image terms like “blurry” or “distorted” with good success.
|
||||
- Some examples prompts you can try on your own:
|
||||
- A detailed oil painting of a tranquil forest at sunset with vibrant+ colors and soft, golden light filtering through the trees
|
||||
- friends sharing beers in a busy city, realistic colored pencil sketch, twilight, masterpiece, bright, lively
|
||||
|
||||
### Generation Workflows
|
||||
|
||||
- Invoke offers a number of different workflows for interacting with models to produce images. Each is extremely powerful on its own, but together provide you an unparalleled way of producing high quality creative outputs that align with your vision.
|
||||
- **Text to Image:** The text to image tab focuses on the key workflow of using a prompt to generate a new image. It includes other features that help control the generation process as well.
|
||||
- **Image to Image:** With image to image, you provide an image as a reference (called the “initial image”), which provides more guidance around color and structure to the AI as it generates a new image. This is provided alongside the same features as Text to Image.
|
||||
- **Unified Canvas:** The Unified Canvas is an advanced AI-first image editing tool that is easy to use, but hard to master. Drag an image onto the canvas from your gallery in order to regenerate certain elements, edit content or colors (known as inpainting), or extend the image with an exceptional degree of consistency and clarity (called outpainting).
|
||||
|
||||
### Improving Image Quality
|
||||
|
||||
- Fine tuning your prompt - the more specific you are, the closer the image will turn out to what is in your head! Adding more details in the Positive Prompt or Negative Prompt can help add / remove pieces of your image to improve it - You can also use advanced techniques like upweighting and downweighting to control the influence of certain words. [Learn more here](https://invoke-ai.github.io/InvokeAI/features/PROMPTS/#prompt-syntax-features).
|
||||
- **Tip: If you’re seeing poor results, try adding the things you don’t like about the image to your negative prompt may help. E.g. distorted, low quality, unrealistic, etc.**
|
||||
- Explore different models - Other models can produce different results due to the data they’ve been trained on. Each model has specific language and settings it works best with; a model’s documentation is your friend here. Play around with some and see what works best for you!
|
||||
- Increasing Steps - The number of steps used controls how much time the model is given to produce an image, and depends on the “Scheduler” used. The schedule controls how each step is processed by the model. More steps tends to mean better results, but will take longer - We recommend at least 30 steps for most
|
||||
- Tweak and Iterate - Remember, it’s best to change one thing at a time so you know what is working and what isn't. Sometimes you just need to try a new image, and other times using a new prompt might be the ticket. For testing, consider turning off the “random” Seed - Using the same seed with the same settings will produce the same image, which makes it the perfect way to learn exactly what your changes are doing.
|
||||
- Explore Advanced Settings - InvokeAI has a full suite of tools available to allow you complete control over your image creation process - Check out our [docs if you want to learn more](https://invoke-ai.github.io/InvokeAI/features/).
|
||||
|
||||
|
||||
## Terms & Concepts
|
||||
|
||||
If you're interested in learning more, check out [this presentation](https://docs.google.com/presentation/d/1IO78i8oEXFTZ5peuHHYkVF-Y3e2M6iM5tCnc-YBfcCM/edit?usp=sharing) from one of our maintainers (@lstein).
|
||||
|
||||
### Stable Diffusion
|
||||
|
||||
Stable Diffusion is deep learning, text-to-image model that is the foundation of the capabilities found in InvokeAI. Since the release of Stable Diffusion, there have been many subsequent models created based on Stable Diffusion that are designed to generate specific types of images.
|
||||
|
||||
### Prompts
|
||||
|
||||
Prompts provide the models directions on what to generate. As a general rule of thumb, the more detailed your prompt is, the better your result will be.
|
||||
|
||||
### Models
|
||||
|
||||
Models are the magic that power InvokeAI. These files represent the output of training a machine on understanding massive amounts of images - providing them with the capability to generate new images using just a text description of what you’d like to see. (Like Stable Diffusion!)
|
||||
|
||||
Invoke offers a simple way to download several different models upon installation, but many more can be discovered online, including at ****. Each model can produce a unique style of output, based on the images it was trained on - Try out different models to see which best fits your creative vision!
|
||||
|
||||
- *Models that contain “inpainting” in the name are designed for use with the inpainting feature of the Unified Canvas*
|
||||
|
||||
### Scheduler
|
||||
|
||||
Schedulers guide the process of removing noise (de-noising) from data. They determine:
|
||||
|
||||
1. The number of steps to take to remove the noise.
|
||||
2. Whether the steps are random (stochastic) or predictable (deterministic).
|
||||
3. The specific method (algorithm) used for de-noising.
|
||||
|
||||
Experimenting with different schedulers is recommended as each will produce different outputs!
|
||||
|
||||
### Steps
|
||||
|
||||
The number of de-noising steps each generation through.
|
||||
|
||||
Schedulers can be intricate and there's often a balance to strike between how quickly they can de-noise data and how well they can do it. It's typically advised to experiment with different schedulers to see which one gives the best results. There has been a lot written on the internet about different schedulers, as well as exploring what the right level of "steps" are for each. You can save generation time by reducing the number of steps used, but you'll want to make sure that you are satisfied with the quality of images produced!
|
||||
|
||||
### Low-Rank Adaptations / LoRAs
|
||||
|
||||
Low-Rank Adaptations (LoRAs) are like a smaller, more focused version of models, intended to focus on training a better understanding of how a specific character, style, or concept looks.
|
||||
|
||||
### Textual Inversion Embeddings
|
||||
|
||||
Textual Inversion Embeddings, like LoRAs, assist with more easily prompting for certain characters, styles, or concepts. However, embeddings are trained to update the relationship between a specific word (known as the “trigger”) and the intended output.
|
||||
|
||||
### ControlNet
|
||||
|
||||
ControlNets are neural network models that are able to extract key features from an existing image and use these features to guide the output of the image generation model.
|
||||
|
||||
### VAE
|
||||
|
||||
Variational auto-encoder (VAE) is a encode/decode model that translates the "latents" image produced during the image generation procees to the large pixel images that we see.
|
||||
|
||||
@@ -11,6 +11,33 @@ title: Home
|
||||
```
|
||||
-->
|
||||
|
||||
<!-- CSS styling -->
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fortawesome/fontawesome-free@6.2.1/css/fontawesome.min.css">
|
||||
<style>
|
||||
.button {
|
||||
width: 300px;
|
||||
height: 50px;
|
||||
background-color: #448AFF;
|
||||
color: #fff;
|
||||
font-size: 16px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
border-radius: 0.2rem;
|
||||
}
|
||||
|
||||
.button-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 300px);
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.button:hover {
|
||||
background-color: #526CFE;
|
||||
}
|
||||
</style>
|
||||
|
||||
|
||||
|
||||
<div align="center" markdown>
|
||||
|
||||
|
||||
@@ -70,61 +97,23 @@ image-to-image generator. It provides a streamlined process with various new
|
||||
features and options to aid the image generation process. It runs on Windows,
|
||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
|
||||
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a
|
||||
href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a
|
||||
href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas &
|
||||
Q&A</a>]
|
||||
|
||||
<div align="center"><img src="assets/invoke-web-server-1.png" width=640></div>
|
||||
|
||||
!!! note
|
||||
!!! Note
|
||||
|
||||
This fork is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates. They will help aid diagnose issues faster.
|
||||
This project is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates as it will help aid response time.
|
||||
|
||||
## :octicons-package-dependencies-24: Installation
|
||||
## :octicons-link-24: Quick Links
|
||||
|
||||
This fork is supported across Linux, Windows and Macintosh. Linux users can use
|
||||
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
|
||||
driver).
|
||||
|
||||
### [Installation Getting Started Guide](installation)
|
||||
#### **[Automated Installer](installation/010_INSTALL_AUTOMATED.md)**
|
||||
✅ This is the recommended installation method for first-time users.
|
||||
#### [Manual Installation](installation/020_INSTALL_MANUAL.md)
|
||||
This method is recommended for experienced users and developers
|
||||
#### [Docker Installation](installation/040_INSTALL_DOCKER.md)
|
||||
This method is recommended for those familiar with running Docker containers
|
||||
### Other Installation Guides
|
||||
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
|
||||
- [XFormers](installation/070_INSTALL_XFORMERS.md)
|
||||
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
|
||||
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
|
||||
|
||||
## :fontawesome-solid-computer: Hardware Requirements
|
||||
|
||||
### :octicons-cpu-24: System
|
||||
|
||||
You wil need one of the following:
|
||||
|
||||
- :simple-nvidia: An NVIDIA-based graphics card with 4 GB or more VRAM memory.
|
||||
- :simple-amd: An AMD-based graphics card with 4 GB or more VRAM memory (Linux
|
||||
only)
|
||||
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
|
||||
|
||||
We do **not recommend** the following video cards due to issues with their
|
||||
running in half-precision mode and having insufficient VRAM to render 512x512
|
||||
images in full-precision mode:
|
||||
|
||||
- NVIDIA 10xx series cards such as the 1080ti
|
||||
- GTX 1650 series cards
|
||||
- GTX 1660 series cards
|
||||
|
||||
### :fontawesome-solid-memory: Memory and Disk
|
||||
|
||||
- At least 12 GB Main Memory RAM.
|
||||
- At least 18 GB of free disk space for the machine learning model, Python, and
|
||||
all its dependencies.
|
||||
<div class="button-container">
|
||||
<a href="installation/INSTALLATION"> <button class="button">Installation</button> </a>
|
||||
<a href="features/"> <button class="button">Features</button> </a>
|
||||
<a href="help/gettingStartedWithAI/"> <button class="button">Getting Started</button> </a>
|
||||
<a href="contributing/CONTRIBUTING/"> <button class="button">Contributing</button> </a>
|
||||
<a href="https://github.com/invoke-ai/InvokeAI/"> <button class="button">Code and Downloads</button> </a>
|
||||
<a href="https://github.com/invoke-ai/InvokeAI/issues"> <button class="button">Bug Reports </button> </a>
|
||||
<a href="https://discord.gg/ZmtBAhwWhy"> <button class="button"> Join the Discord Server!</button> </a>
|
||||
</div>
|
||||
|
||||
|
||||
## :octicons-gift-24: InvokeAI Features
|
||||
@@ -230,7 +219,7 @@ encouraged to do so.
|
||||
|
||||
## :octicons-person-24: Contributors
|
||||
|
||||
This fork is a combined effort of various people from across the world.
|
||||
This software is a combined effort of various people from across the world.
|
||||
[Check out the list of all these amazing people](other/CONTRIBUTORS.md). We
|
||||
thank them for their time, hard work and effort.
|
||||
|
||||
|
||||
@@ -40,10 +40,8 @@ experimental versions later.
|
||||
this, open up a command-line window ("Terminal" on Linux and
|
||||
Macintosh, "Command" or "Powershell" on Windows) and type `python
|
||||
--version`. If Python is installed, it will print out the version
|
||||
number. If it is version `3.9.*` or `3.10.*`, you meet
|
||||
requirements. We do not recommend using Python 3.11 or higher,
|
||||
as not all the libraries that InvokeAI depends on work properly
|
||||
with this version.
|
||||
number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
|
||||
requirements.
|
||||
|
||||
!!! warning "What to do if you have an unsupported version"
|
||||
|
||||
@@ -266,7 +264,7 @@ experimental versions later.
|
||||
you can create several levels of subfolders and drop your models into
|
||||
whichever ones you want.
|
||||
|
||||
- ***Autoimport FolderLICENSE***
|
||||
- ***LICENSE***
|
||||
|
||||
At the bottom of the screen you will see a checkbox for accepting
|
||||
the CreativeML Responsible AI Licenses. You need to accept the license
|
||||
@@ -374,8 +372,71 @@ experimental versions later.
|
||||
Once InvokeAI is installed, do not move or remove this directory."
|
||||
|
||||
|
||||
<a name="troubleshooting"></a>
|
||||
## Troubleshooting
|
||||
|
||||
### _OSErrors on Windows while installing dependencies_
|
||||
|
||||
During a zip file installation or an online update, installation stops
|
||||
with an error like this:
|
||||
|
||||
{:width="800px"}
|
||||
|
||||
This seems to happen particularly often with the `pydantic` and
|
||||
`numpy` packages. The most reliable solution requires several manual
|
||||
steps to complete installation.
|
||||
|
||||
Open up a Powershell window and navigate to the `invokeai` directory
|
||||
created by the installer. Then give the following series of commands:
|
||||
|
||||
```cmd
|
||||
rm .\.venv -r -force
|
||||
python -mvenv .venv
|
||||
.\.venv\Scripts\activate
|
||||
pip install invokeai
|
||||
invokeai-configure --yes --root .
|
||||
```
|
||||
|
||||
If you see anything marked as an error during this process please stop
|
||||
and seek help on the Discord [installation support
|
||||
channel](https://discord.com/channels/1020123559063990373/1041391462190956654). A
|
||||
few warning messages are OK.
|
||||
|
||||
If you are updating from a previous version, this should restore your
|
||||
system to a working state. If you are installing from scratch, there
|
||||
is one additional command to give:
|
||||
|
||||
```cmd
|
||||
wget -O invoke.bat https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/installer/templates/invoke.bat.in
|
||||
```
|
||||
|
||||
This will create the `invoke.bat` script needed to launch InvokeAI and
|
||||
its related programs.
|
||||
|
||||
|
||||
### _Stable Diffusion XL Generation Fails after Trying to Load unet_
|
||||
|
||||
InvokeAI is working in other respects, but when trying to generate
|
||||
images with Stable Diffusion XL you get a "Server Error". The text log
|
||||
in the launch window contains this log line above several more lines of
|
||||
error messages:
|
||||
|
||||
```INFO --> Loading model:D:\LONG\PATH\TO\MODEL, type sdxl:main:unet```
|
||||
|
||||
This failure mode occurs when there is a network glitch during
|
||||
downloading the very large SDXL model.
|
||||
|
||||
To address this, first go to the Web Model Manager and delete the
|
||||
Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
|
||||
manually download the .safetensors version of the model. The 1.0
|
||||
version is located at
|
||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
|
||||
and the file is named `sd_xl_base_1.0.safetensors`.
|
||||
|
||||
Save this file to disk and then reenter the Model Manager. Navigate to
|
||||
Import Models->Add Model, then type (or drag-and-drop) the path to the
|
||||
.safetensors file. Press "Add Model".
|
||||
|
||||
### _Package dependency conflicts_
|
||||
|
||||
If you have previously installed InvokeAI or another Stable Diffusion
|
||||
@@ -410,7 +471,7 @@ Then type the following commands:
|
||||
|
||||
=== "NVIDIA System"
|
||||
```bash
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install xformers
|
||||
```
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ gaming):
|
||||
|
||||
* **Python**
|
||||
|
||||
version 3.9 or 3.10 (3.11 is not recommended).
|
||||
version 3.9 through 3.11
|
||||
|
||||
* **CUDA Tools**
|
||||
|
||||
@@ -65,7 +65,7 @@ gaming):
|
||||
To install InvokeAI with virtual environments and the PIP package
|
||||
manager, please follow these steps:
|
||||
|
||||
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install
|
||||
1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
|
||||
procedure depends on this and will not work with other versions:
|
||||
|
||||
```bash
|
||||
@@ -148,7 +148,7 @@ manager, please follow these steps:
|
||||
=== "CUDA (NVidia)"
|
||||
|
||||
```bash
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
@@ -192,8 +192,10 @@ manager, please follow these steps:
|
||||
your outputs.
|
||||
|
||||
```terminal
|
||||
invokeai-configure
|
||||
invokeai-configure --root .
|
||||
```
|
||||
|
||||
Don't miss the dot at the end of the command!
|
||||
|
||||
The script `invokeai-configure` will interactively guide you through the
|
||||
process of downloading and installing the weights files needed for InvokeAI.
|
||||
@@ -225,12 +227,6 @@ manager, please follow these steps:
|
||||
|
||||
!!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
|
||||
|
||||
=== "CLI"
|
||||
|
||||
```bash
|
||||
invokeai
|
||||
```
|
||||
|
||||
=== "local Webserver"
|
||||
|
||||
```bash
|
||||
@@ -243,6 +239,12 @@ manager, please follow these steps:
|
||||
invokeai --web --host 0.0.0.0
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
||||
```bash
|
||||
invokeai
|
||||
```
|
||||
|
||||
If you choose the run the web interface, point your browser at
|
||||
http://localhost:9090 in order to load the GUI.
|
||||
|
||||
@@ -310,7 +312,7 @@ installation protocol (important!)
|
||||
|
||||
=== "CUDA (NVidia)"
|
||||
```bash
|
||||
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
@@ -354,7 +356,7 @@ you can do so using this unsupported recipe:
|
||||
mkdir ~/invokeai
|
||||
conda create -n invokeai python=3.10
|
||||
conda activate invokeai
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
invokeai-configure --root ~/invokeai
|
||||
invokeai --root ~/invokeai --web
|
||||
```
|
||||
|
||||
@@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
|
||||
nvidia-cuda-toolkit package. It is out of date and will cause
|
||||
conflicts among the NVIDIA driver and binaries.**
|
||||
|
||||
Go to [CUDA Toolkit 11.7
|
||||
Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
|
||||
and use the target selection wizard to choose your operating system,
|
||||
hardware platform, and preferred installation method (e.g. "local"
|
||||
versus "network").
|
||||
Go to [CUDA Toolkit
|
||||
Downloads](https://developer.nvidia.com/cuda-downloads), and use the
|
||||
target selection wizard to choose your operating system, hardware
|
||||
platform, and preferred installation method (e.g. "local" versus
|
||||
"network").
|
||||
|
||||
This will provide you with a downloadable install file or, depending
|
||||
on your choices, a recipe for downloading and running a install shell
|
||||
@@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
the argument `--extra-index-url
|
||||
https://download.pytorch.org/whl/cu117` as described in the [Manual
|
||||
https://download.pytorch.org/whl/cu118` as described in the [Manual
|
||||
Installation Guide](020_INSTALL_MANUAL.md).
|
||||
|
||||
## :simple-amd: ROCm
|
||||
|
||||
@@ -124,7 +124,7 @@ installation. Examples:
|
||||
invokeai-model-install --list controlnet
|
||||
|
||||
# (install the model at the indicated URL)
|
||||
invokeai-model-install --add http://civitai.com/2860
|
||||
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
||||
|
||||
# (delete the named model)
|
||||
invokeai-model-install --delete sd-1/main/analog-diffusion
|
||||
@@ -170,4 +170,4 @@ elsewhere on disk and they will be autoimported. You can also create
|
||||
subfolders and organize them as you wish.
|
||||
|
||||
The location of the autoimport directories are controlled by settings
|
||||
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
||||
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
||||
|
||||
@@ -28,18 +28,21 @@ command line, then just be sure to activate it's virtual environment.
|
||||
Then run the following three commands:
|
||||
|
||||
```sh
|
||||
pip install xformers==0.0.16rc425
|
||||
pip install triton
|
||||
pip install xformers~=0.0.19
|
||||
pip install triton # WON'T WORK ON WINDOWS
|
||||
python -m xformers.info output
|
||||
```
|
||||
|
||||
The first command installs `xformers`, the second installs the
|
||||
`triton` training accelerator, and the third prints out the `xformers`
|
||||
installation status. If all goes well, you'll see a report like the
|
||||
installation status. On Windows, please omit the `triton` package,
|
||||
which is not available on that platform.
|
||||
|
||||
If all goes well, you'll see a report like the
|
||||
following:
|
||||
|
||||
```sh
|
||||
xFormers 0.0.16rc425
|
||||
xFormers 0.0.20
|
||||
memory_efficient_attention.cutlassF: available
|
||||
memory_efficient_attention.cutlassB: available
|
||||
memory_efficient_attention.flshattF: available
|
||||
@@ -48,22 +51,28 @@ memory_efficient_attention.smallkF: available
|
||||
memory_efficient_attention.smallkB: available
|
||||
memory_efficient_attention.tritonflashattF: available
|
||||
memory_efficient_attention.tritonflashattB: available
|
||||
indexing.scaled_index_addF: available
|
||||
indexing.scaled_index_addB: available
|
||||
indexing.index_select: available
|
||||
swiglu.dual_gemm_silu: available
|
||||
swiglu.gemm_fused_operand_sum: available
|
||||
swiglu.fused.p.cpp: available
|
||||
is_triton_available: True
|
||||
is_functorch_available: False
|
||||
pytorch.version: 1.13.1+cu117
|
||||
pytorch.version: 2.0.1+cu118
|
||||
pytorch.cuda: available
|
||||
gpu.compute_capability: 8.6
|
||||
gpu.name: NVIDIA RTX A2000 12GB
|
||||
gpu.compute_capability: 8.9
|
||||
gpu.name: NVIDIA GeForce RTX 4070
|
||||
build.info: available
|
||||
build.cuda_version: 1107
|
||||
build.python_version: 3.10.9
|
||||
build.torch_version: 1.13.1+cu117
|
||||
build.cuda_version: 1108
|
||||
build.python_version: 3.10.11
|
||||
build.torch_version: 2.0.1+cu118
|
||||
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
|
||||
build.env.XFORMERS_BUILD_TYPE: Release
|
||||
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
|
||||
build.env.NVCC_FLAGS: None
|
||||
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
|
||||
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
|
||||
build.nvcc_version: 11.8.89
|
||||
source.privacy: open source
|
||||
```
|
||||
|
||||
@@ -83,14 +92,14 @@ installed from source. These instructions were written for a system
|
||||
running Ubuntu 22.04, but other Linux distributions should be able to
|
||||
adapt this recipe.
|
||||
|
||||
#### 1. Install CUDA Toolkit 11.7
|
||||
#### 1. Install CUDA Toolkit 11.8
|
||||
|
||||
You will need the CUDA developer's toolkit in order to compile and
|
||||
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
|
||||
package.** It is out of date and will cause conflicts among the NVIDIA
|
||||
driver and binaries. Instead install the CUDA Toolkit package provided
|
||||
by NVIDIA itself. Go to [CUDA Toolkit 11.7
|
||||
Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
|
||||
by NVIDIA itself. Go to [CUDA Toolkit 11.8
|
||||
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
|
||||
and use the target selection wizard to choose your platform and Linux
|
||||
distribution. Select an installer type of "runfile (local)" at the
|
||||
last step.
|
||||
@@ -101,17 +110,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
|
||||
x86_64 system is:
|
||||
|
||||
```
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
|
||||
sudo sh cuda_11.7.0_515.43.04_linux.run
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
|
||||
sudo sh cuda_11.8.0_520.61.05_linux.run
|
||||
```
|
||||
|
||||
Rather than cut-and-paste this example, We recommend that you walk
|
||||
through the toolkit wizard in order to get the most up to date
|
||||
installer for your system.
|
||||
|
||||
#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support
|
||||
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
|
||||
|
||||
If you are using InvokeAI 2.3 or higher, these will already be
|
||||
If you are using InvokeAI 3.0.2 or higher, these will already be
|
||||
installed. If not, you can check whether you have the needed libraries
|
||||
using a quick command. Activate the invokeai virtual environment,
|
||||
either by entering the "developer's console", or manually with a
|
||||
@@ -124,7 +133,7 @@ Then run the command:
|
||||
python -c 'exec("import torch\nprint(torch.__version__)")'
|
||||
```
|
||||
|
||||
If it prints __1.13.1+cu117__ you're good. If not, you can install the
|
||||
If it prints __1.13.1+cu118__ you're good. If not, you can install the
|
||||
most up to date libraries with this command:
|
||||
|
||||
```sh
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
---
|
||||
title: Overview
|
||||
---
|
||||
# Overview
|
||||
|
||||
We offer several ways to install InvokeAI, each one suited to your
|
||||
experience and preferences. We suggest that everyone start by
|
||||
@@ -15,6 +13,56 @@ See the [troubleshooting
|
||||
section](010_INSTALL_AUTOMATED.md#troubleshooting) of the automated
|
||||
install guide for frequently-encountered installation issues.
|
||||
|
||||
This fork is supported across Linux, Windows and Macintosh. Linux users can use
|
||||
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
|
||||
driver).
|
||||
|
||||
### [Installation Getting Started Guide](installation)
|
||||
#### **[Automated Installer](010_INSTALL_AUTOMATED.md)**
|
||||
✅ This is the recommended installation method for first-time users.
|
||||
#### [Manual Installation](020_INSTALL_MANUAL.md)
|
||||
This method is recommended for experienced users and developers
|
||||
#### [Docker Installation](040_INSTALL_DOCKER.md)
|
||||
This method is recommended for those familiar with running Docker containers
|
||||
### Other Installation Guides
|
||||
- [PyPatchMatch](060_INSTALL_PATCHMATCH.md)
|
||||
- [XFormers](070_INSTALL_XFORMERS.md)
|
||||
- [CUDA and ROCm Drivers](030_INSTALL_CUDA_AND_ROCM.md)
|
||||
- [Installing New Models](050_INSTALLING_MODELS.md)
|
||||
|
||||
## :fontawesome-solid-computer: Hardware Requirements
|
||||
|
||||
### :octicons-cpu-24: System
|
||||
|
||||
You wil need one of the following:
|
||||
|
||||
- :simple-nvidia: An NVIDIA-based graphics card with 4 GB or more VRAM memory.
|
||||
- :simple-amd: An AMD-based graphics card with 4 GB or more VRAM memory (Linux
|
||||
only)
|
||||
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
|
||||
|
||||
** SDXL 1.0 Requirements*
|
||||
To use SDXL, user must have one of the following:
|
||||
- :simple-nvidia: An NVIDIA-based graphics card with 8 GB or more VRAM memory.
|
||||
- :simple-amd: An AMD-based graphics card with 16 GB or more VRAM memory (Linux
|
||||
only)
|
||||
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
|
||||
|
||||
|
||||
### :fontawesome-solid-memory: Memory and Disk
|
||||
|
||||
- At least 12 GB Main Memory RAM.
|
||||
- At least 18 GB of free disk space for the machine learning model, Python, and
|
||||
all its dependencies.
|
||||
|
||||
We do **not recommend** the following video cards due to issues with their
|
||||
running in half-precision mode and having insufficient VRAM to render 512x512
|
||||
images in full-precision mode:
|
||||
|
||||
- NVIDIA 10xx series cards such as the 1080ti
|
||||
- GTX 1650 series cards
|
||||
- GTX 1660 series cards
|
||||
|
||||
## Installation options
|
||||
|
||||
1. [Automated Installer](010_INSTALL_AUTOMATED.md)
|
||||
@@ -14,23 +14,28 @@ The nodes linked below have been developed and contributed by members of the Inv
|
||||
|
||||
## List of Nodes
|
||||
|
||||
### Face Mask
|
||||
### FaceTools
|
||||
|
||||
**Description:** This node autodetects a face in the image using MediaPipe and masks it by making it transparent. Via outpainting you can swap faces with other faces, or invert the mask and swap things around the face with other things. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control. The node also outputs an all-white mask in the same dimensions as the input image. This is needed by the inpaint node (and unified canvas) for outpainting.
|
||||
**Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
|
||||
|
||||
**Node Link:** https://github.com/ymgenesis/InvokeAI/blob/facemaskmediapipe/invokeai/app/invocations/facemask.py
|
||||
**Node Link:** https://github.com/ymgenesis/FaceTools/
|
||||
|
||||
**Example Node Graph:** https://www.mediafire.com/file/gohn5sb1bfp8use/21-July_2023-FaceMask.json/file
|
||||
**FaceMask Output Examples**
|
||||
|
||||
**Output Examples**
|
||||

|
||||

|
||||

|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
<hr>
|
||||
|
||||
### Ideal Size
|
||||
|
||||
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/ideal-size-node
|
||||
|
||||
--------------------------------
|
||||
### Super Cool Node Template
|
||||
### Example Node Template
|
||||
|
||||
**Description:** This node allows you to do super cool things with InvokeAI.
|
||||
|
||||
@@ -40,13 +45,9 @@ The nodes linked below have been developed and contributed by members of the Inv
|
||||
|
||||
**Output Examples**
|
||||
|
||||

|
||||
|
||||
### Ideal Size
|
||||
|
||||
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/ideal-size-node
|
||||
{: style="height:115px;width:240px"}
|
||||
|
||||
## Help
|
||||
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).
|
||||
|
||||
|
||||
|
||||
25
flake.lock
generated
Normal file
25
flake.lock
generated
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1690630721,
|
||||
"narHash": "sha256-Y04onHyBQT4Erfr2fc82dbJTfXGYrf4V0ysLUYnPOP8=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "d2b52322f35597c62abf56de91b0236746b2a03d",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
91
flake.nix
Normal file
91
flake.nix
Normal file
@@ -0,0 +1,91 @@
|
||||
# Important note: this flake does not attempt to create a fully isolated, 'pure'
|
||||
# Python environment for InvokeAI. Instead, it depends on local invocations of
|
||||
# virtualenv/pip to install the required (binary) packages, most importantly the
|
||||
# prebuilt binary pytorch packages with CUDA support.
|
||||
# ML Python packages with CUDA support, like pytorch, are notoriously expensive
|
||||
# to compile so it's purposefuly not what this flake does.
|
||||
|
||||
{
|
||||
description = "An (impure) flake to develop on InvokeAI.";
|
||||
|
||||
outputs = { self, nixpkgs }:
|
||||
let
|
||||
system = "x86_64-linux";
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfree = true;
|
||||
};
|
||||
|
||||
python = pkgs.python310;
|
||||
|
||||
mkShell = { dir, install }:
|
||||
let
|
||||
setupScript = pkgs.writeScript "setup-invokai" ''
|
||||
# This must be sourced using 'source', not executed.
|
||||
${python}/bin/python -m venv ${dir}
|
||||
${dir}/bin/python -m pip install ${install}
|
||||
# ${dir}/bin/python -c 'import torch; assert(torch.cuda.is_available())'
|
||||
source ${dir}/bin/activate
|
||||
'';
|
||||
in
|
||||
pkgs.mkShell rec {
|
||||
buildInputs = with pkgs; [
|
||||
# Backend: graphics, CUDA.
|
||||
cudaPackages.cudnn
|
||||
cudaPackages.cuda_nvrtc
|
||||
cudatoolkit
|
||||
pkgconfig
|
||||
libconfig
|
||||
cmake
|
||||
blas
|
||||
freeglut
|
||||
glib
|
||||
gperf
|
||||
procps
|
||||
libGL
|
||||
libGLU
|
||||
linuxPackages.nvidia_x11
|
||||
python
|
||||
(opencv4.override {
|
||||
enableGtk3 = true;
|
||||
enableFfmpeg = true;
|
||||
enableCuda = true;
|
||||
enableUnfree = true;
|
||||
})
|
||||
stdenv.cc
|
||||
stdenv.cc.cc.lib
|
||||
xorg.libX11
|
||||
xorg.libXext
|
||||
xorg.libXi
|
||||
xorg.libXmu
|
||||
xorg.libXrandr
|
||||
xorg.libXv
|
||||
zlib
|
||||
|
||||
# Pre-commit hooks.
|
||||
black
|
||||
|
||||
# Frontend.
|
||||
yarn
|
||||
nodejs
|
||||
];
|
||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
|
||||
CUDA_PATH = pkgs.cudatoolkit;
|
||||
EXTRA_LDFLAGS = "-L${pkgs.linuxPackages.nvidia_x11}/lib";
|
||||
shellHook = ''
|
||||
if [[ -f "${dir}/bin/activate" ]]; then
|
||||
source "${dir}/bin/activate"
|
||||
echo "Using Python: $(which python)"
|
||||
else
|
||||
echo "Use 'source ${setupScript}' to set up the environment."
|
||||
fi
|
||||
'';
|
||||
};
|
||||
in
|
||||
{
|
||||
devShells.${system} = rec {
|
||||
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
|
||||
default = develop;
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -9,16 +9,20 @@ cd $scriptdir
|
||||
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
|
||||
|
||||
MINIMUM_PYTHON_VERSION=3.9.0
|
||||
MAXIMUM_PYTHON_VERSION=3.11.0
|
||||
MAXIMUM_PYTHON_VERSION=3.11.100
|
||||
PYTHON=""
|
||||
for candidate in python3.10 python3.9 python3 python ; do
|
||||
for candidate in python3.11 python3.10 python3.9 python3 python ; do
|
||||
if ppath=`which $candidate`; then
|
||||
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
|
||||
# we check that this found executable can actually run
|
||||
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
|
||||
|
||||
python_version=$($ppath -V | awk '{ print $2 }')
|
||||
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
|
||||
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
|
||||
PYTHON=$ppath
|
||||
break
|
||||
fi
|
||||
if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
|
||||
PYTHON=$ppath
|
||||
break
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
@@ -13,7 +13,7 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Union
|
||||
|
||||
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
|
||||
SUPPORTED_PYTHON = ">=3.9.0,<=3.11.100"
|
||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
|
||||
|
||||
@@ -141,15 +141,16 @@ class Installer:
|
||||
|
||||
# upgrade pip in Python 3.9 environments
|
||||
if int(platform.python_version_tuple()[1]) == 9:
|
||||
|
||||
from plumbum import FG, local
|
||||
|
||||
pip = local[get_pip_from_venv(venv_dir)]
|
||||
pip[ "install", "--upgrade", "pip"] & FG
|
||||
pip["install", "--upgrade", "pip"] & FG
|
||||
|
||||
return venv_dir
|
||||
|
||||
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
|
||||
def install(
|
||||
self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None
|
||||
) -> None:
|
||||
"""
|
||||
Install the InvokeAI application into the given runtime path
|
||||
|
||||
@@ -167,7 +168,8 @@ class Installer:
|
||||
|
||||
messages.welcome()
|
||||
|
||||
self.dest = Path(root).expanduser().resolve() if yes_to_all else messages.dest_path(root)
|
||||
default_path = os.environ.get("INVOKEAI_ROOT") or Path(root).expanduser().resolve()
|
||||
self.dest = default_path if yes_to_all else messages.dest_path(root)
|
||||
|
||||
# create the venv for the app
|
||||
self.venv = self.app_venv()
|
||||
@@ -175,7 +177,7 @@ class Installer:
|
||||
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
|
||||
|
||||
# install dependencies and the InvokeAI application
|
||||
(extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
|
||||
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
|
||||
self.instance.install(
|
||||
extra_index_url,
|
||||
optional_modules,
|
||||
@@ -188,6 +190,7 @@ class Installer:
|
||||
# run through the configuration flow
|
||||
self.instance.configure()
|
||||
|
||||
|
||||
class InvokeAiInstance:
|
||||
"""
|
||||
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
|
||||
@@ -196,7 +199,6 @@ class InvokeAiInstance:
|
||||
"""
|
||||
|
||||
def __init__(self, runtime: Path, venv: Path, version: str) -> None:
|
||||
|
||||
self.runtime = runtime
|
||||
self.venv = venv
|
||||
self.pip = get_pip_from_venv(venv)
|
||||
@@ -247,6 +249,9 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
|
||||
"urllib3~=1.26.0",
|
||||
"requests~=2.28.0",
|
||||
"torch~=2.0.0",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchvision>=0.14.1",
|
||||
@@ -312,7 +317,7 @@ class InvokeAiInstance:
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"--use-pep517",
|
||||
str(src)+(optional_modules if optional_modules else ''),
|
||||
str(src) + (optional_modules if optional_modules else ""),
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
@@ -329,21 +334,21 @@ class InvokeAiInstance:
|
||||
|
||||
# set sys.argv to a consistent state
|
||||
new_argv = [sys.argv[0]]
|
||||
for i in range(1,len(sys.argv)):
|
||||
for i in range(1, len(sys.argv)):
|
||||
el = sys.argv[i]
|
||||
if el in ['-r','--root']:
|
||||
if el in ["-r", "--root"]:
|
||||
new_argv.append(el)
|
||||
new_argv.append(sys.argv[i+1])
|
||||
elif el in ['-y','--yes','--yes-to-all']:
|
||||
new_argv.append(sys.argv[i + 1])
|
||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||
new_argv.append(el)
|
||||
sys.argv = new_argv
|
||||
|
||||
|
||||
import requests # to catch download exceptions
|
||||
from messages import introduction
|
||||
|
||||
introduction()
|
||||
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
# NOTE: currently the config script does its own arg parsing! this means the command-line switches
|
||||
# from the installer will also automatically propagate down to the config script.
|
||||
@@ -353,16 +358,16 @@ class InvokeAiInstance:
|
||||
invokeai_configure()
|
||||
succeeded = True
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print(f'\nA network error was encountered during configuration and download: {str(e)}')
|
||||
print(f"\nA network error was encountered during configuration and download: {str(e)}")
|
||||
except OSError as e:
|
||||
print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
|
||||
print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
|
||||
except Exception as e:
|
||||
print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
|
||||
print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
|
||||
finally:
|
||||
if not succeeded:
|
||||
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
|
||||
print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
|
||||
print('Alternatively you can relaunch the installer.')
|
||||
print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
|
||||
print("Alternatively you can relaunch the installer.")
|
||||
|
||||
def install_user_scripts(self):
|
||||
"""
|
||||
@@ -371,11 +376,11 @@ class InvokeAiInstance:
|
||||
|
||||
ext = "bat" if OS == "Windows" else "sh"
|
||||
|
||||
#scripts = ['invoke', 'update']
|
||||
scripts = ['invoke']
|
||||
|
||||
# scripts = ['invoke', 'update']
|
||||
scripts = ["invoke"]
|
||||
|
||||
for script in scripts:
|
||||
src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
|
||||
src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
|
||||
dest = self.runtime / f"{script}.{ext}"
|
||||
shutil.copy(src, dest)
|
||||
os.chmod(dest, 0o0755)
|
||||
@@ -420,11 +425,7 @@ def set_sys_path(venv_path: Path) -> None:
|
||||
# filter out any paths in sys.path that may be system- or user-wide
|
||||
# but leave the temporary bootstrap virtualenv as it contains packages we
|
||||
# temporarily need at install time
|
||||
sys.path = list(filter(
|
||||
lambda p: not p.endswith("-packages")
|
||||
or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
|
||||
sys.path
|
||||
))
|
||||
sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
|
||||
|
||||
# determine site-packages/lib directory location for the venv
|
||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||
@@ -433,7 +434,7 @@ def set_sys_path(venv_path: Path) -> None:
|
||||
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
||||
|
||||
|
||||
def get_torch_source() -> (Union[str, None],str):
|
||||
def get_torch_source() -> (Union[str, None], str):
|
||||
"""
|
||||
Determine the extra index URL for pip to use for torch installation.
|
||||
This depends on the OS and the graphics accelerator in use.
|
||||
@@ -454,16 +455,19 @@ def get_torch_source() -> (Union[str, None],str):
|
||||
device = graphical_accelerator()
|
||||
|
||||
url = None
|
||||
optional_modules = None
|
||||
optional_modules = "[onnx]"
|
||||
if OS == "Linux":
|
||||
if device == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||
elif device == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
if device == 'cuda':
|
||||
url = 'https://download.pytorch.org/whl/cu117'
|
||||
optional_modules = '[xformers]'
|
||||
if device == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
if device == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ InvokeAI Installer
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from installer import Installer
|
||||
|
||||
@@ -15,7 +16,7 @@ if __name__ == "__main__":
|
||||
dest="root",
|
||||
type=str,
|
||||
help="Destination path for installation",
|
||||
default="~/invokeai",
|
||||
default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-y",
|
||||
@@ -41,7 +42,7 @@ if __name__ == "__main__":
|
||||
type=Path,
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
inst = Installer()
|
||||
|
||||
@@ -36,13 +36,15 @@ else:
|
||||
|
||||
|
||||
def welcome():
|
||||
|
||||
@group()
|
||||
def text():
|
||||
if (platform_specific := _platform_specific_help()) != "":
|
||||
yield platform_specific
|
||||
yield ""
|
||||
yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
|
||||
yield Text.from_markup(
|
||||
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
|
||||
justify="center",
|
||||
)
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@@ -58,6 +60,7 @@ def welcome():
|
||||
)
|
||||
console.line()
|
||||
|
||||
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||
@@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
|
||||
dest_confirmed = confirm_install(dest)
|
||||
|
||||
while not dest_confirmed:
|
||||
|
||||
# if the given destination already exists, the starting point for browsing is its parent directory.
|
||||
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
|
||||
# if the destination dir does NOT exist, then the user must have changed their mind about the selection.
|
||||
@@ -165,6 +167,10 @@ def graphical_accelerator():
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||
"cuda",
|
||||
)
|
||||
nvidia_with_dml = (
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||
"cuda_and_dml",
|
||||
)
|
||||
amd = (
|
||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||
"rocm",
|
||||
@@ -179,7 +185,7 @@ def graphical_accelerator():
|
||||
)
|
||||
|
||||
if OS == "Windows":
|
||||
options = [nvidia, cpu]
|
||||
options = [nvidia, nvidia_with_dml, cpu]
|
||||
if OS == "Linux":
|
||||
options = [nvidia, amd, cpu]
|
||||
elif OS == "Darwin":
|
||||
@@ -300,15 +306,20 @@ def introduction() -> None:
|
||||
)
|
||||
console.line(2)
|
||||
|
||||
def _platform_specific_help()->str:
|
||||
|
||||
def _platform_specific_help() -> str:
|
||||
if OS == "Darwin":
|
||||
text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
|
||||
)
|
||||
elif OS == "Windows":
|
||||
text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
|
||||
enable long path support on your system.
|
||||
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
return text
|
||||
|
||||
@@ -8,16 +8,13 @@ Preparations:
|
||||
to work. Instructions are given here:
|
||||
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
|
||||
NOTE: At this time we do not recommend Python 3.11. We recommend
|
||||
Version 3.10.9, which has been extensively tested with InvokeAI.
|
||||
|
||||
Before you start the installer, please open up your system's command
|
||||
line window (Terminal or Command) and type the commands:
|
||||
|
||||
python --version
|
||||
|
||||
If all is well, it will print "Python 3.X.X", where the version number
|
||||
is at least 3.9.1, and less than 3.11.
|
||||
is at least 3.9.*, and not higher than 3.11.*.
|
||||
|
||||
If this works, check the version of the Python package manager, pip:
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ IF /I "%choice%" == "1" (
|
||||
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
||||
) ELSE IF /I "%choice%" == "7" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --yes --default_only
|
||||
python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
|
||||
) ELSE IF /I "%choice%" == "8" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
|
||||
@@ -82,7 +82,7 @@ do_choice() {
|
||||
7)
|
||||
clear
|
||||
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
|
||||
;;
|
||||
8)
|
||||
clear
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
@@ -29,6 +29,7 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@@ -54,7 +55,7 @@ logger = InvokeAILogger.getLogger()
|
||||
class ApiDependencies:
|
||||
"""Contains and initializes all dependencies for the API"""
|
||||
|
||||
invoker: Invoker = None
|
||||
invoker: Invoker
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||
@@ -67,8 +68,9 @@ class ApiDependencies:
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
@@ -78,9 +80,7 @@ class ApiDependencies:
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
@@ -125,12 +125,11 @@ class ApiDependencies:
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from invokeai.version import __version__
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
|
||||
class LogLevel(int, Enum):
|
||||
NotSet = logging.NOTSET
|
||||
Debug = logging.DEBUG
|
||||
@@ -23,10 +24,12 @@ class LogLevel(int, Enum):
|
||||
Error = logging.ERROR
|
||||
Critical = logging.CRITICAL
|
||||
|
||||
|
||||
class Upscaler(BaseModel):
|
||||
upscaling_method: str = Field(description="Name of upscaling method")
|
||||
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
|
||||
|
||||
|
||||
|
||||
app_router = APIRouter(prefix="/v1/app", tags=["app"])
|
||||
|
||||
|
||||
@@ -45,38 +48,30 @@ class AppConfig(BaseModel):
|
||||
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
|
||||
)
|
||||
@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
|
||||
async def get_version() -> AppVersion:
|
||||
return AppVersion(version=__version__)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
|
||||
)
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ['tile']
|
||||
infill_methods = ["tile"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append('patchmatch')
|
||||
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
upscaling_models = []
|
||||
for model in typing.get_args(ESRGAN_MODELS):
|
||||
upscaling_models.append(str(Path(model).stem))
|
||||
upscaler = Upscaler(
|
||||
upscaling_method = 'esrgan',
|
||||
upscaling_models = upscaling_models
|
||||
)
|
||||
|
||||
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
||||
|
||||
nsfw_methods = []
|
||||
if SafetyChecker.safety_checker_available():
|
||||
nsfw_methods.append('nsfw_checker')
|
||||
nsfw_methods.append("nsfw_checker")
|
||||
|
||||
watermarking_methods = []
|
||||
if InvisibleWatermark.invisible_watermark_available():
|
||||
watermarking_methods.append('invisible_watermark')
|
||||
|
||||
watermarking_methods.append("invisible_watermark")
|
||||
|
||||
return AppConfig(
|
||||
infill_methods=infill_methods,
|
||||
upscaling_methods=[upscaler],
|
||||
@@ -84,25 +79,26 @@ async def get_config() -> AppConfig:
|
||||
watermarking_methods=watermarking_methods,
|
||||
)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
response_model=LogLevel,
|
||||
)
|
||||
async def get_log_level(
|
||||
) -> LogLevel:
|
||||
async def get_log_level() -> LogLevel:
|
||||
"""Returns the log level"""
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
||||
|
||||
@app_router.post(
|
||||
"/logging",
|
||||
operation_id="set_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
response_model=LogLevel,
|
||||
)
|
||||
async def set_log_level(
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
) -> LogLevel:
|
||||
"""Sets the log verbosity level"""
|
||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||
|
||||
@@ -1,24 +1,30 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
class AddImagesToBoardResult(BaseModel):
|
||||
board_id: str = Field(description="The id of the board the images were added to")
|
||||
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||
|
||||
|
||||
class RemoveImagesFromBoardResult(BaseModel):
|
||||
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="create_board_image",
|
||||
operation_id="add_image_to_board",
|
||||
responses={
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board_image(
|
||||
async def add_image_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
@@ -29,27 +35,78 @@ async def create_board_image(
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||
|
||||
|
||||
@board_images_router.delete(
|
||||
"/",
|
||||
operation_id="remove_board_image",
|
||||
operation_id="remove_image_from_board",
|
||||
responses={
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def remove_board_image(
|
||||
board_id: str = Body(description="The id of the board"),
|
||||
image_name: str = Body(description="The name of the image to remove"),
|
||||
async def remove_image_from_board(
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
):
|
||||
"""Deletes a board_image"""
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/batch",
|
||||
operation_id="add_images_to_board",
|
||||
responses={
|
||||
201: {"description": "Images were added to board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_images_to_board(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_names: list[str] = Body(description="The names of the images to add", embed=True),
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
try:
|
||||
added_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
added_image_names.append(image_name)
|
||||
except:
|
||||
pass
|
||||
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/batch/delete",
|
||||
operation_id="remove_images_from_board",
|
||||
responses={
|
||||
201: {"description": "Images were removed from board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_images_from_board(
|
||||
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
try:
|
||||
removed_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_image_names.append(image_name)
|
||||
except:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||
|
||||
@@ -18,9 +18,7 @@ class DeleteBoardResult(BaseModel):
|
||||
deleted_board_images: list[str] = Field(
|
||||
description="The image names of the board-images relationships that were deleted."
|
||||
)
|
||||
deleted_images: list[str] = Field(
|
||||
description="The names of the images that were deleted."
|
||||
)
|
||||
deleted_images: list[str] = Field(description="The names of the images that were deleted.")
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
@@ -73,22 +71,16 @@ async def update_board(
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete(
|
||||
"/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
|
||||
)
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
include_images: Optional[bool] = Query(
|
||||
description="Permanently delete all images on the board", default=False
|
||||
),
|
||||
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
|
||||
) -> DeleteBoardResult:
|
||||
"""Deletes a board"""
|
||||
try:
|
||||
@@ -96,9 +88,7 @@ async def delete_board(
|
||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
return DeleteBoardResult(
|
||||
board_id=board_id,
|
||||
@@ -127,9 +117,7 @@ async def delete_board(
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
|
||||
@@ -1,31 +1,31 @@
|
||||
import io
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
ImageUrlsDTO,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/",
|
||||
"/upload",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully"},
|
||||
@@ -40,15 +40,9 @@ async def upload_image(
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board to add this image to, if any"
|
||||
),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
crop_visible: Optional[bool] = Query(
|
||||
default=False, description="Whether to crop the image"
|
||||
),
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
@@ -83,7 +77,7 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
@images_router.delete("/{image_name}", operation_id="delete_image")
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
@@ -109,15 +103,13 @@ async def clear_intermediates() -> int:
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
"/{image_name}",
|
||||
"/i/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
),
|
||||
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
|
||||
@@ -128,7 +120,7 @@ async def update_image(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}",
|
||||
"/i/{image_name}",
|
||||
operation_id="get_image_dto",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
@@ -144,7 +136,7 @@ async def get_image_dto(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/metadata",
|
||||
"/i/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageMetadata,
|
||||
)
|
||||
@@ -159,8 +151,9 @@ async def get_image_metadata(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/full",
|
||||
@images_router.api_route(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -195,7 +188,7 @@ async def get_image_full(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/thumbnail",
|
||||
"/i/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@@ -212,15 +205,11 @@ async def get_image_thumbnail(
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
response = FileResponse(
|
||||
path, media_type="image/webp", content_disposition_type="inline"
|
||||
)
|
||||
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception as e:
|
||||
@@ -228,7 +217,7 @@ async def get_image_thumbnail(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_name}/urls",
|
||||
"/i/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
@@ -239,9 +228,7 @@ async def get_image_urls(
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
|
||||
return ImageUrlsDTO(
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
@@ -257,15 +244,9 @@ async def get_image_urls(
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_image_dtos(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list."
|
||||
),
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include."
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images."
|
||||
),
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
@@ -285,3 +266,62 @@ async def list_image_dtos(
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
|
||||
class DeleteImagesFromListResult(BaseModel):
|
||||
deleted_images: list[str]
|
||||
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||
async def delete_images_from_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesFromListResult:
|
||||
try:
|
||||
deleted_images: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.append(image_name)
|
||||
except:
|
||||
pass
|
||||
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
class ImagesUpdatedFromListResult(BaseModel):
|
||||
updated_image_names: list[str] = Field(description="The image names that were updated")
|
||||
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
async def star_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
|
||||
updated_image_names.append(image_name)
|
||||
except:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
|
||||
async def unstar_images_in_list(
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> ImagesUpdatedFromListResult:
|
||||
try:
|
||||
updated_image_names: list[str] = []
|
||||
for image_name in image_names:
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
|
||||
updated_image_names.append(image_name)
|
||||
except:
|
||||
pass
|
||||
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
@@ -28,49 +28,52 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
responses={200: {"model": ModelsList}},
|
||||
)
|
||||
async def list_models(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
if base_models and len(base_models)>0:
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={200: {"description" : "The model was updated successfully"},
|
||||
400: {"description" : "Bad request"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
409: {"description" : "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = UpdateModelResponse,
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
|
||||
try:
|
||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
@@ -81,13 +84,13 @@ async def update_model(
|
||||
# rename operation requested
|
||||
if info.model_name != model_name or info.base_model != base_model:
|
||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = info.model_name,
|
||||
new_base = info.base_model,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=info.model_name,
|
||||
new_base=info.base_model,
|
||||
)
|
||||
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
|
||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||
# update information to support an update of attributes
|
||||
model_name = info.model_name
|
||||
base_model = info.base_model
|
||||
@@ -96,16 +99,19 @@ async def update_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get('path')
|
||||
|
||||
if new_info.get("path") != previous_info.get(
|
||||
"path"
|
||||
): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get("path")
|
||||
|
||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||
info_dict = info.dict()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info.dict()
|
||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
|
||||
)
|
||||
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -123,49 +129,48 @@ async def update_model(
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
415: {"description" : "Unrecognized file/folder format"},
|
||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
||||
),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
logger.info(f'Successfully imported {location}, got {info}')
|
||||
|
||||
logger.info(f"Successfully imported {location}, got {info}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -175,38 +180,34 @@ async def import_model(
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses= {
|
||||
201: {"description" : "The model added successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes = info.dict()
|
||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||
)
|
||||
logger.info(f'Successfully added {info.model_name}')
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
@@ -216,66 +217,66 @@ async def add_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: { "description": "Model deleted successfully" },
|
||||
404: { "description": "Model not found" }
|
||||
},
|
||||
status_code = 204,
|
||||
response_model = None,
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: {"description" : "Bad request" },
|
||||
404: { "description": "Model not found" },
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = ConvertModelResponse,
|
||||
status_code=200,
|
||||
response_model=ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(
|
||||
default=None, description="Save the converted model to the designated directory"
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
convert_dest_directory = dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type)
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
convert_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
@@ -283,91 +284,101 @@ async def convert_model(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
responses={
|
||||
200: { "description": "Directory searched successfully" },
|
||||
404: { "description": "Invalid directory path" },
|
||||
200: {"description": "Directory searched successfully"},
|
||||
404: {"description": "Invalid directory path"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||
)->List[pathlib.Path]:
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
responses={
|
||||
200: { "description" : "paths retrieved successfully" },
|
||||
200: {"description": "paths retrieved successfully"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def list_ckpt_configs(
|
||||
)->List[pathlib.Path]:
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
|
||||
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/sync",
|
||||
operation_id="sync_to_config",
|
||||
responses={
|
||||
201: { "description": "synchronization successful" },
|
||||
201: {"description": "synchronization successful"},
|
||||
},
|
||||
status_code = 201,
|
||||
response_model = bool
|
||||
status_code=201,
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config(
|
||||
)->bool:
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: { "description": "Incompatible models" },
|
||||
404: { "description": "One or more models not found" },
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = MergeModelResponse,
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(
|
||||
description="Force merging of models created with different versions of diffusers", default=False
|
||||
),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory = dest
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
)
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
|
||||
@@ -30,9 +30,7 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
},
|
||||
)
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(
|
||||
default=None, description="The graph to initialize the session with"
|
||||
)
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
@@ -42,7 +40,7 @@ async def create_session(
|
||||
@session_router.get(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
responses={200: {"model": PaginatedResults[dict]}},
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
@@ -51,13 +49,9 @@ async def list_sessions(
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||
page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(
|
||||
query, page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@@ -91,9 +85,9 @@ async def get_session(
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The node to add"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The node to add"
|
||||
),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@@ -124,9 +118,9 @@ async def add_node(
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The new node"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The new node"
|
||||
),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@@ -230,7 +224,7 @@ async def delete_edge(
|
||||
try:
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
@@ -255,9 +249,7 @@ async def delete_edge(
|
||||
)
|
||||
async def invoke_session(
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@@ -274,9 +266,7 @@ async def invoke_session(
|
||||
@session_router.delete(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={
|
||||
202: {"description": "The invocation is canceled"}
|
||||
},
|
||||
responses={202: {"description": "The invocation is canceled"}},
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
|
||||
@@ -16,9 +16,7 @@ class SocketIO:
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
local_handler.register(
|
||||
event_name=EventServiceBase.session_event, _func=self._handle_session_event
|
||||
)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
import sys
|
||||
from inspect import signature
|
||||
|
||||
import logging
|
||||
import uvicorn
|
||||
import socket
|
||||
|
||||
@@ -16,9 +17,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
#This should come early so that modules can log their initialization properly
|
||||
# This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
@@ -27,7 +29,7 @@ from invokeai.version.invokeai_version import __version__
|
||||
# we call this early so that the message appears before
|
||||
# other invokeai initialization messages
|
||||
if app_config.version:
|
||||
print(f'InvokeAI version {__version__}')
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
@@ -36,18 +38,19 @@ import mimetypes
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||
|
||||
|
||||
import torch
|
||||
import invokeai.backend.util.hotfixes
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
@@ -57,14 +60,13 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
event_handler_id: int = id(app)
|
||||
app.add_middleware(
|
||||
EventHandlerASGIMiddleware,
|
||||
handlers=[
|
||||
local_handler
|
||||
], # TODO: consider doing this in services to support different configurations
|
||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
@@ -76,9 +78,7 @@ async def startup_event():
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@@ -103,7 +103,8 @@ app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix='/api')
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@@ -133,6 +134,11 @@ def custom_openapi():
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||
for schema_key, output_schema in ui_config_schemas["definitions"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__
|
||||
@@ -144,6 +150,7 @@ def custom_openapi():
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
@@ -166,7 +173,8 @@ def custom_openapi():
|
||||
app.openapi = custom_openapi
|
||||
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
@@ -187,11 +195,8 @@ def overridden_redoc():
|
||||
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/",
|
||||
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||
html=True
|
||||
), name="ui"
|
||||
)
|
||||
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui")
|
||||
|
||||
|
||||
def invoke_api():
|
||||
def find_port(port: int):
|
||||
@@ -203,19 +208,35 @@ def invoke_api():
|
||||
return find_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
|
||||
|
||||
from invokeai.backend.install.check_root import check_invokeai_root
|
||||
|
||||
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
|
||||
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
|
||||
# Start our own event loop for eventing usage
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
port=port,
|
||||
loop=loop,
|
||||
log_level=app_config.log_level,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
for logname in ["uvicorn.access", "uvicorn"]:
|
||||
l = logging.getLogger(logname)
|
||||
l.handlers.clear()
|
||||
for ch in logger.handlers:
|
||||
l.addHandler(ch)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -14,8 +14,14 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
def add_field_argument(command_parser, name: str, field, default_override=None):
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
@@ -47,8 +53,8 @@ def add_parsers(
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
|
||||
):
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
@@ -61,7 +67,7 @@ def add_parsers(
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Convert all fields to arguments
|
||||
fields = command.__fields__ # type: ignore
|
||||
fields = command.__fields__ # type: ignore
|
||||
for name, field in fields.items():
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
@@ -70,13 +76,11 @@ def add_parsers(
|
||||
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
@@ -128,6 +132,7 @@ class CliContext:
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -155,7 +160,7 @@ class BaseCommand(ABC, BaseModel):
|
||||
@classmethod
|
||||
def get_commands_map(cls):
|
||||
# Get the type strings out of the literals and into a dictionary
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
|
||||
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
|
||||
|
||||
@abstractmethod
|
||||
def run(self, context: CliContext) -> None:
|
||||
@@ -165,7 +170,8 @@ class BaseCommand(ABC, BaseModel):
|
||||
|
||||
class ExitCommand(BaseCommand):
|
||||
"""Exits the CLI"""
|
||||
type: Literal['exit'] = 'exit'
|
||||
|
||||
type: Literal["exit"] = "exit"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
raise ExitCli()
|
||||
@@ -173,7 +179,8 @@ class ExitCommand(BaseCommand):
|
||||
|
||||
class HelpCommand(BaseCommand):
|
||||
"""Shows help"""
|
||||
type: Literal['help'] = 'help'
|
||||
|
||||
type: Literal["help"] = "help"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
context.parser.print_help()
|
||||
@@ -183,11 +190,7 @@ def get_graph_execution_history(
|
||||
graph_execution_state: GraphExecutionState,
|
||||
) -> Iterable[str]:
|
||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||
return (
|
||||
n
|
||||
for n in reversed(graph_execution_state.executed_history)
|
||||
if n in graph_execution_state.graph.nodes
|
||||
)
|
||||
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
|
||||
|
||||
|
||||
def get_invocation_command(invocation) -> str:
|
||||
@@ -218,7 +221,8 @@ def get_invocation_command(invocation) -> str:
|
||||
|
||||
class HistoryCommand(BaseCommand):
|
||||
"""Shows the invocation history"""
|
||||
type: Literal['history'] = 'history'
|
||||
|
||||
type: Literal["history"] = "history"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
@@ -235,7 +239,8 @@ class HistoryCommand(BaseCommand):
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
"""Sets a default value for a field"""
|
||||
type: Literal['default'] = 'default'
|
||||
|
||||
type: Literal["default"] = "default"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
@@ -253,7 +258,8 @@ class SetDefaultCommand(BaseCommand):
|
||||
|
||||
class DrawGraphCommand(BaseCommand):
|
||||
"""Debugs a graph"""
|
||||
type: Literal['draw_graph'] = 'draw_graph'
|
||||
|
||||
type: Literal["draw_graph"] = "draw_graph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
@@ -271,7 +277,8 @@ class DrawGraphCommand(BaseCommand):
|
||||
|
||||
class DrawExecutionGraphCommand(BaseCommand):
|
||||
"""Debugs an execution graph"""
|
||||
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||
|
||||
type: Literal["draw_xgraph"] = "draw_xgraph"
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
@@ -286,6 +293,7 @@ class DrawExecutionGraphCommand(BaseCommand):
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
|
||||
@@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
|
||||
|
||||
class Completer(object):
|
||||
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
@@ -43,7 +43,7 @@ class Completer(object):
|
||||
except IndexError:
|
||||
pass
|
||||
options = options or list(self.parse_commands().keys())
|
||||
|
||||
|
||||
if not text: # first time
|
||||
self.matches = options
|
||||
else:
|
||||
@@ -56,17 +56,17 @@ class Completer(object):
|
||||
return match
|
||||
|
||||
@classmethod
|
||||
def get_commands(self)->List[object]:
|
||||
def get_commands(self) -> List[object]:
|
||||
"""
|
||||
Return a list of all the client commands and invocations.
|
||||
"""
|
||||
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||
|
||||
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||
def get_current_command(self, buffer: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse the readline buffer to find the most recent command and its switch.
|
||||
"""
|
||||
if len(buffer)==0:
|
||||
if len(buffer) == 0:
|
||||
return None, None
|
||||
tokens = shlex.split(buffer)
|
||||
command = None
|
||||
@@ -78,11 +78,11 @@ class Completer(object):
|
||||
else:
|
||||
switch = t
|
||||
# don't try to autocomplete switches that are already complete
|
||||
if switch and buffer.endswith(' '):
|
||||
switch=None
|
||||
return command or '', switch or ''
|
||||
if switch and buffer.endswith(" "):
|
||||
switch = None
|
||||
return command or "", switch or ""
|
||||
|
||||
def parse_commands(self)->Dict[str, List[str]]:
|
||||
def parse_commands(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict in which the keys are the command name
|
||||
and the values are the parameters the command takes.
|
||||
@@ -90,11 +90,11 @@ class Completer(object):
|
||||
result = dict()
|
||||
for command in self.commands:
|
||||
hints = get_type_hints(command)
|
||||
name = get_args(hints['type'])[0]
|
||||
result.update({name:hints})
|
||||
name = get_args(hints["type"])[0]
|
||||
result.update({name: hints})
|
||||
return result
|
||||
|
||||
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||
def get_command_options(self, command: str, switch: str) -> List[str]:
|
||||
"""
|
||||
Return all the parameters that can be passed to the command as
|
||||
command-line switches. Returns None if the command is unrecognized.
|
||||
@@ -102,42 +102,46 @@ class Completer(object):
|
||||
parsed_commands = self.parse_commands()
|
||||
if command not in parsed_commands:
|
||||
return None
|
||||
|
||||
|
||||
# handle switches in the format "-foo=bar"
|
||||
argument = None
|
||||
if switch and '=' in switch:
|
||||
switch, argument = switch.split('=')
|
||||
|
||||
parameter = switch.strip('-')
|
||||
if switch and "=" in switch:
|
||||
switch, argument = switch.split("=")
|
||||
|
||||
parameter = switch.strip("-")
|
||||
if parameter in parsed_commands[command]:
|
||||
if argument is None:
|
||||
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
else:
|
||||
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||
return [
|
||||
f"--{parameter}={x}"
|
||||
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||
]
|
||||
else:
|
||||
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||
|
||||
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
|
||||
"""
|
||||
Given a parameter type (such as Literal), offers autocompletions.
|
||||
"""
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == 'model':
|
||||
if parameter == "model":
|
||||
return self.manager.model_names()
|
||||
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
|
||||
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
@@ -162,8 +166,6 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic.fields import Field
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
@@ -20,7 +21,7 @@ from invokeai.version.invokeai_version import __version__
|
||||
|
||||
# we call this early so that the message appears before other invokeai initialization messages
|
||||
if config.version:
|
||||
print(f'InvokeAI version {__version__}')
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
@@ -36,18 +37,22 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||
create_system_graphs)
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||
SortedHelpFormatter, add_graph_parsers, add_parsers)
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
|
||||
GraphInvocation, LibraryGraph,
|
||||
are_connection_types_compatible)
|
||||
from .services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
GraphExecutionState,
|
||||
GraphInvocation,
|
||||
LibraryGraph,
|
||||
are_connection_types_compatible,
|
||||
)
|
||||
from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@@ -58,6 +63,7 @@ from .services.sqlite import SqliteItemStorage
|
||||
|
||||
import torch
|
||||
import invokeai.backend.util.hotfixes
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
|
||||
@@ -69,6 +75,7 @@ class CliCommand(BaseModel):
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
@@ -113,7 +120,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
class NodeField:
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
@@ -126,15 +133,20 @@ class NodeField():
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
|
||||
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
return NodeField(
|
||||
alias=exposed_input.alias,
|
||||
node_path=f"{node_id}.{exposed_input.node_path}",
|
||||
field=exposed_input.field,
|
||||
field_type=get_type_hints(node_type)[exposed_input.field],
|
||||
)
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
@@ -142,7 +154,12 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
return NodeField(
|
||||
alias=exposed_output.alias,
|
||||
node_path=f"{node_id}.{exposed_output.node_path}",
|
||||
field=exposed_output.field,
|
||||
field_type=get_type_hints(node_output_type)[exposed_output.field],
|
||||
)
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
@@ -165,9 +182,7 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
@@ -179,12 +194,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
matching_fields = [
|
||||
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
|
||||
]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
|
||||
)
|
||||
for alias in matching_fields
|
||||
]
|
||||
@@ -193,6 +210,7 @@ def generate_matching_edges(
|
||||
|
||||
class SessionError(Exception):
|
||||
"""Raised when a session error has occurred"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -209,22 +227,23 @@ def invoke_all(context: CliContext):
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
|
||||
raise SessionError()
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
logger.info(f'InvokeAI version {__version__}')
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
parser.add_argument("commands", nargs="*")
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
sys.stdin = open(infile, "r")
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
@@ -234,13 +253,13 @@ def invoke_cli():
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
@@ -281,24 +300,22 @@ def invoke_cli():
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
@@ -308,7 +325,7 @@ def invoke_cli():
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
re_negid = re.compile("^-[0-9]+$")
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
@@ -318,7 +335,7 @@ def invoke_cli():
|
||||
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
|
||||
while not done:
|
||||
try:
|
||||
if command_line_args_exist:
|
||||
@@ -332,7 +349,7 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
# history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
@@ -353,17 +370,17 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
command: CliCommand = None # type:ignore
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: Optional[LibraryGraph] = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
if args["type"] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
command = CliCommand(command=invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
@@ -385,17 +402,13 @@ def invoke_cli():
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
from_id = (
|
||||
history[0] if current_id == start_id else str(current_id - 1)
|
||||
)
|
||||
from_id = history[0] if current_id == start_id else str(current_id - 1)
|
||||
from_node = (
|
||||
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
||||
if current_id != start_id
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command, context
|
||||
)
|
||||
matching_edges = generate_matching_edges(from_node, command.command, context)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
# Parse provided links
|
||||
@@ -406,16 +419,18 @@ def invoke_cli():
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_edges = generate_matching_edges(link_node, command.command, context)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
edges.extend(matching_edges)
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||
edges = [
|
||||
e
|
||||
for e in edges
|
||||
if e.destination.node_id != command.command.id or e.destination.field != link[2]
|
||||
]
|
||||
|
||||
node_id = link[0]
|
||||
if re_negid.match(node_id):
|
||||
@@ -428,7 +443,7 @@ def invoke_cli():
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -4,9 +4,5 @@ __all__ = []
|
||||
|
||||
dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
for f in os.listdir(dirname):
|
||||
if (
|
||||
f != "__init__.py"
|
||||
and os.path.isfile("%s/%s" % (dirname, f))
|
||||
and f[-3:] == ".py"
|
||||
):
|
||||
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
|
||||
__all__.append(f[:-3])
|
||||
|
||||
@@ -3,16 +3,366 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
|
||||
get_type_hints)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseConfig, BaseModel, Field
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import Undefined
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||
skipped_layers = "Number of layers to skip in text encoder"
|
||||
seed = "Seed for random number generation"
|
||||
steps = "Number of steps to run"
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
core_metadata = "Optional core metadata to be written to image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
precision = "Precision to use"
|
||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||
detect_res = "Pixel resolution for detection"
|
||||
image_res = "Pixel resolution for output image"
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
|
||||
|
||||
class Input(str, Enum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class UIType(str, Enum):
|
||||
"""
|
||||
Type hints for the UI.
|
||||
If a field should be provided a data type that does not exactly match the python type of the field, \
|
||||
use this to provide the type that should be used instead. See the node development docs for detail \
|
||||
on adding a new field type, which involves client-side changes.
|
||||
"""
|
||||
|
||||
# region Primitives
|
||||
Integer = "integer"
|
||||
Float = "float"
|
||||
Boolean = "boolean"
|
||||
String = "string"
|
||||
Array = "array"
|
||||
Image = "ImageField"
|
||||
Latents = "LatentsField"
|
||||
Conditioning = "ConditioningField"
|
||||
Control = "ControlField"
|
||||
Color = "ColorField"
|
||||
ImageCollection = "ImageCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
StringCollection = "StringCollection"
|
||||
BooleanCollection = "BooleanCollection"
|
||||
# endregion
|
||||
|
||||
# region Models
|
||||
MainModel = "MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VaeModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
UNet = "UNetField"
|
||||
Vae = "VaeField"
|
||||
CLIP = "ClipField"
|
||||
# endregion
|
||||
|
||||
# region Iterate/Collect
|
||||
Collection = "Collection"
|
||||
CollectionItem = "CollectionItem"
|
||||
# endregion
|
||||
|
||||
# region Misc
|
||||
FilePath = "FilePath"
|
||||
Enum = "enum"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are \
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class _InputField(BaseModel):
|
||||
"""
|
||||
*DO NOT USE*
|
||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||
purpose in the backend.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_component: Optional[UIComponent]
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
"""
|
||||
*DO NOT USE*
|
||||
This helper class is used to tell the client about our custom field attributes via OpenAPI
|
||||
schema generation, and Typescript type generation from that schema. It serves no functional
|
||||
purpose in the backend.
|
||||
"""
|
||||
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
|
||||
|
||||
def InputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
*args: Any,
|
||||
default: Any = Undefined,
|
||||
default_factory: Optional[NoArgAnyCallable] = None,
|
||||
alias: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None,
|
||||
const: Optional[bool] = None,
|
||||
gt: Optional[float] = None,
|
||||
ge: Optional[float] = None,
|
||||
lt: Optional[float] = None,
|
||||
le: Optional[float] = None,
|
||||
multiple_of: Optional[float] = None,
|
||||
allow_inf_nan: Optional[bool] = None,
|
||||
max_digits: Optional[int] = None,
|
||||
decimal_places: Optional[int] = None,
|
||||
min_items: Optional[int] = None,
|
||||
max_items: Optional[int] = None,
|
||||
unique_items: Optional[bool] = None,
|
||||
min_length: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
repr: bool = True,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
default=default,
|
||||
default_factory=default_factory,
|
||||
alias=alias,
|
||||
title=title,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
const=const,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_items=min_items,
|
||||
max_items=max_items,
|
||||
unique_items=unique_items,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
allow_mutation=allow_mutation,
|
||||
regex=regex,
|
||||
discriminator=discriminator,
|
||||
repr=repr,
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
This is used internally by the @tags and @title decorator logic. You probably want to use those
|
||||
decorators, though you may add this class to a node definition to specify the title and tags.
|
||||
"""
|
||||
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
|
||||
title: Optional[str] = Field(default=None, description="The display name of the node")
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
@@ -40,6 +390,20 @@ class BaseInvocationOutput(BaseModel):
|
||||
return tuple(subclasses)
|
||||
|
||||
|
||||
class RequiredConnectionException(Exception):
|
||||
"""Raised when an field which requires a connection did not receive a value."""
|
||||
|
||||
def __init__(self, node_id: str, field_name: str):
|
||||
super().__init__(f"Node {node_id} missing connections for field {field_name}")
|
||||
|
||||
|
||||
class MissingInputException(Exception):
|
||||
"""Raised when an field which requires some input, but did not receive a value."""
|
||||
|
||||
def __init__(self, node_id: str, field_name: str):
|
||||
super().__init__(f"Node {node_id} missing value or connection for field {field_name}")
|
||||
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
"""A node to process inputs and produce outputs.
|
||||
May use dependency injection in __init__ to receive providers.
|
||||
@@ -77,70 +441,81 @@ class BaseInvocation(ABC, BaseModel):
|
||||
def get_output_type(cls):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
class Config:
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
if uiconfig and hasattr(uiconfig, "title"):
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig and hasattr(uiconfig, "tags"):
|
||||
schema["tags"] = uiconfig.tags
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
# fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||
# fmt: on
|
||||
def __init__(self, **data):
|
||||
# nodes may have required fields, that can accept input from connections
|
||||
# on instantiation of the model, we need to exclude these from validation
|
||||
restore = dict()
|
||||
try:
|
||||
field_names = list(self.__fields__.keys())
|
||||
for field_name in field_names:
|
||||
# if the field is required and may get its value from a connection, exclude it from validation
|
||||
field = self.__fields__[field_name]
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if _input in [Input.Connection, Input.Any] and field.required:
|
||||
if field_name not in data:
|
||||
restore[field_name] = self.__fields__.pop(field_name)
|
||||
# instantiate the node, which will validate the data
|
||||
super().__init__(**data)
|
||||
finally:
|
||||
# restore the removed fields
|
||||
for field_name, field in restore.items():
|
||||
self.__fields__[field_name] = field
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
for field_name, field in self.__fields__.items():
|
||||
_input = field.field_info.extra.get("input", None)
|
||||
if field.required and not hasattr(self, field_name):
|
||||
if _input == Input.Connection:
|
||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||
elif _input == Input.Any:
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
return self.invoke(context)
|
||||
|
||||
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
|
||||
is_intermediate: bool = InputField(
|
||||
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
||||
)
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
|
||||
# TODO: figure out a better way to provide these hints
|
||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
||||
class UIConfig(TypedDict, total=False):
|
||||
type_hints: Dict[
|
||||
str,
|
||||
Literal[
|
||||
"integer",
|
||||
"float",
|
||||
"boolean",
|
||||
"string",
|
||||
"enum",
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
"control",
|
||||
"image_collection",
|
||||
"vae_model",
|
||||
"lora_model",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
title: str
|
||||
T = TypeVar("T", bound=BaseInvocation)
|
||||
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
def title(title: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
|
||||
|
||||
def wrapper(cls: Type[T]) -> Type[T]:
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig.title = title
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class InvocationConfig(BaseConfig):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
def wrapper(cls: Type[T]) -> Type[T]:
|
||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||
cls.UIConfig.tags = list(tags)
|
||||
return cls
|
||||
|
||||
`tags`
|
||||
- A list of strings, used to categorise invocations.
|
||||
|
||||
`type_hints`
|
||||
- A dict of field types which override the types in the invocation definition.
|
||||
- Each key should be the name of one of the invocation's fields.
|
||||
- Each value should be one of the valid types:
|
||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
||||
|
||||
```python
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"initial_image": "image",
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
schema_extra: CustomisedSchemaExtra
|
||||
return wrapper
|
||||
|
||||
@@ -3,64 +3,25 @@
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
from pydantic import validator
|
||||
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageCollectionOutput, ImageField, IntegerCollectionOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext, UIConfig)
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of integers"""
|
||||
|
||||
type: Literal["int_collection"] = "int_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of floats"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = Field(
|
||||
default=[], description="The float collection")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of images"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = Field(
|
||||
default=[], description="The output images")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "collection"]}
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIType, tags, title
|
||||
|
||||
|
||||
@title("Integer Range")
|
||||
@tags("collection", "integer", "range")
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
type: Literal["range"] = "range"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
stop: int = Field(default=10, description="The stop of the range")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Range",
|
||||
"tags": ["range", "integer", "collection"]
|
||||
},
|
||||
}
|
||||
start: int = InputField(default=0, description="The start of the range")
|
||||
stop: int = InputField(default=10, description="The stop of the range")
|
||||
step: int = InputField(default=1, description="The step of the range")
|
||||
|
||||
@validator("stop")
|
||||
def stop_gt_start(cls, v, values):
|
||||
@@ -68,94 +29,44 @@ class RangeInvocation(BaseInvocation):
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
|
||||
|
||||
@title("Integer Range of Size")
|
||||
@tags("range", "integer", "size", "collection")
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
|
||||
type: Literal["range_of_size"] = "range_of_size"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
size: int = Field(default=1, description="The number of values")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
start: int = InputField(default=0, description="The start of the range")
|
||||
size: int = InputField(default=1, description="The number of values")
|
||||
step: int = InputField(default=1, description="The step of the range")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Sized Range",
|
||||
"tags": ["range", "integer", "size", "collection"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(
|
||||
collection=list(
|
||||
range(
|
||||
self.start, self.start + self.size,
|
||||
self.step)))
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||
|
||||
|
||||
@title("Random Range")
|
||||
@tags("range", "integer", "random", "collection")
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
type: Literal["random_range"] = "random_range"
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: int = Field(
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
size: int = InputField(default=1, description="The number of values to generate")
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Random Range",
|
||||
"tags": ["range", "integer", "random", "collection"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntCollectionOutput(
|
||||
collection=list(
|
||||
rng.integers(
|
||||
low=self.low, high=self.high,
|
||||
size=self.size)))
|
||||
|
||||
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""Load a collection of images and provide it as output."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
images: list[ImageField] = Field(
|
||||
default=[], description="The image collection to load"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.images)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"title": "Image Collection",
|
||||
"images": "image_collection",
|
||||
}
|
||||
},
|
||||
}
|
||||
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||
|
||||
@@ -1,130 +1,116 @@
|
||||
from typing import Literal, Optional, Union, List, Annotated
|
||||
from pydantic import BaseModel, Field
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import (Blend, Conjunction,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt, Fragment)
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.model_management import ModelType
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from ...backend.model_management import ModelPatcher, ModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .model import ClipField
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(
|
||||
default=None, description="The name of conditioning data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
#type: Literal["basic_conditioning"] = "basic_conditioning"
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
ConditioningInfoType = Annotated[
|
||||
Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
Field(discriminator="type")
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
||||
#unconditioned: Optional[torch.Tensor]
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
#class ConditioningAlgo(str, Enum):
|
||||
|
||||
# class ConditioningAlgo(str, Enum):
|
||||
# Compose = "compose"
|
||||
# ComposeEx = "compose_ex"
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||
#fmt: on
|
||||
|
||||
|
||||
@title("Compel Prompt")
|
||||
@tags("prompt", "compel")
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
prompt: str = InputField(
|
||||
default="",
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
clip: ClipField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(), context=context,
|
||||
**self.clip.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(), context=context,
|
||||
**self.clip.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, self.clip.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -139,14 +125,12 @@ class CompelInvocation(BaseInvocation):
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||
prompt)
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(
|
||||
tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get(
|
||||
"cross_attention_control", None),)
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
|
||||
@@ -162,131 +146,93 @@ class CompelInvocation(BaseInvocation):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_field: ClipField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(), context=context,
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(), context=context,
|
||||
**clip_field.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
cpu_text_encoder = text_encoder_info.context.model
|
||||
c = torch.zeros(
|
||||
(1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size),
|
||||
dtype=text_encoder_info.context.cache.precision,
|
||||
)
|
||||
if get_pooled:
|
||||
c_pooled = prompt_embeds[0]
|
||||
c_pooled = torch.zeros(
|
||||
(1, cpu_text_encoder.config.hidden_size),
|
||||
dtype=c.dtype,
|
||||
)
|
||||
else:
|
||||
c_pooled = None
|
||||
c = prompt_embeds.hidden_states[-2]
|
||||
|
||||
del tokenizer
|
||||
del text_encoder
|
||||
del tokenizer_info
|
||||
del text_encoder_info
|
||||
|
||||
c = c.detach().to("cpu")
|
||||
if c_pooled is not None:
|
||||
c_pooled = c_pooled.detach().to("cpu")
|
||||
|
||||
return c, c_pooled, None
|
||||
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(), context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(), context=context,
|
||||
)
|
||||
return c, c_pooled, None
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
# print(e)
|
||||
#import traceback
|
||||
#print(traceback.format_exc())
|
||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora(
|
||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, clip_field.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=True,
|
||||
)
|
||||
|
||||
@@ -320,49 +266,44 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
return c, c_pooled, ec
|
||||
|
||||
|
||||
@title("SDXL Compel Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
style: str = Field(default="", description="Style prompt")
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
target_width: int = Field(1024, description="")
|
||||
target_height: int = Field(1024, description="")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + target_size
|
||||
])
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
@@ -378,47 +319,39 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Compel Prompt")
|
||||
@tags("sdxl", "compel", "prompt")
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
aesthetic_score: float = Field(6.0, description="")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
style: str = InputField(
|
||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||
) # TODO: ?
|
||||
original_width: int = InputField(default=1024, description="")
|
||||
original_height: int = InputField(default=1024, description="")
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + (self.aesthetic_score,)
|
||||
])
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
@@ -426,7 +359,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -434,127 +367,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Pass unmodified prompt to conditioning without compel processing."""
|
||||
|
||||
type: Literal["sdxl_raw_prompt"] = "sdxl_raw_prompt"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
style: str = Field(default="", description="Style prompt")
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
target_width: int = Field(1024, description="")
|
||||
target_height: int = Field(1024, description="")
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
clip2: ClipField = Field(None, description="Clip2 to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Prompt (Raw)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
target_size = (self.target_height, self.target_width)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + target_size
|
||||
])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=torch.cat([c1, c2], dim=-1),
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec1,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||
|
||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||
original_width: int = Field(1024, description="")
|
||||
original_height: int = Field(1024, description="")
|
||||
crop_top: int = Field(0, description="")
|
||||
crop_left: int = Field(0, description="")
|
||||
aesthetic_score: float = Field(6.0, description="")
|
||||
clip2: ClipField = Field(None, description="Clip to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Prompt (Raw)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
crop_coords = (self.crop_top, self.crop_left)
|
||||
|
||||
add_time_ids = torch.tensor([
|
||||
original_size + crop_coords + (self.aesthetic_score,)
|
||||
])
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
embeds=c2,
|
||||
pooled_embeds=c2_pooled,
|
||||
add_time_ids=add_time_ids,
|
||||
extra_conditioning=ec2, # or None
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return CompelOutput(
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
@@ -563,23 +376,20 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@title("CLIP Skip")
|
||||
@tags("clipskip", "clip", "skip")
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
type: Literal["clip_skip"] = "clip_skip"
|
||||
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "CLIP Skip",
|
||||
"tags": ["clip", "skip"]
|
||||
},
|
||||
}
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
@@ -589,46 +399,26 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
truncate_if_too_long=False) -> int:
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in blend.prompts
|
||||
]
|
||||
)
|
||||
return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts])
|
||||
elif type(prompt) is Conjunction:
|
||||
conjunction: Conjunction = prompt
|
||||
return sum(
|
||||
[
|
||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||
for p in conjunction.prompts
|
||||
]
|
||||
)
|
||||
return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts])
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(
|
||||
tokenizer, prompt, truncate_if_too_long))
|
||||
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> List[str]:
|
||||
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
@@ -639,25 +429,17 @@ def get_tokens_for_prompt_object(
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer, display_label_prefix=None
|
||||
):
|
||||
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts) > 1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(
|
||||
p,
|
||||
tokenizer,
|
||||
display_label_prefix=this_display_label_prefix
|
||||
)
|
||||
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
@@ -694,13 +476,10 @@ def log_tokenization_for_prompt_object(
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_text(
|
||||
text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
|
||||
@@ -6,88 +6,53 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
|
||||
LeresDetector, LineartAnimeDetector,
|
||||
LineartDetector, MediapipeFaceDetector,
|
||||
MidasDetector, MLSDdetector, NormalBaeDetector,
|
||||
OpenposeDetector, PidiNetDetector, SamDetector,
|
||||
ZoeDetector)
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
ContentShuffleDetector,
|
||||
HEDdetector,
|
||||
LeresDetector,
|
||||
LineartAnimeDetector,
|
||||
LineartDetector,
|
||||
MediapipeFaceDetector,
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
OpenposeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from ..models.image import ImageOutput, PILInvocationConfig
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.0 models
|
||||
##############################################
|
||||
"lllyasviel/sd-controlnet-canny",
|
||||
"lllyasviel/sd-controlnet-depth",
|
||||
"lllyasviel/sd-controlnet-hed",
|
||||
"lllyasviel/sd-controlnet-seg",
|
||||
"lllyasviel/sd-controlnet-openpose",
|
||||
"lllyasviel/sd-controlnet-scribble",
|
||||
"lllyasviel/sd-controlnet-normal",
|
||||
"lllyasviel/sd-controlnet-mlsd",
|
||||
|
||||
#############################################
|
||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||
#############################################
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
"lllyasviel/control_v11p_sd15_openpose",
|
||||
"lllyasviel/control_v11p_sd15_seg",
|
||||
# "lllyasviel/control_v11p_sd15_depth", # broken
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
"lllyasviel/control_v11p_sd15_normalbae",
|
||||
"lllyasviel/control_v11p_sd15_scribble",
|
||||
"lllyasviel/control_v11p_sd15_mlsd",
|
||||
"lllyasviel/control_v11p_sd15_softedge",
|
||||
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
"lllyasviel/control_v11p_sd15_lineart",
|
||||
"lllyasviel/control_v11p_sd15_inpaint",
|
||||
# "lllyasviel/control_v11u_sd15_tile",
|
||||
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
||||
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
||||
"lllyasviel/control_v11e_sd15_shuffle",
|
||||
"lllyasviel/control_v11e_sd15_ip2p",
|
||||
"lllyasviel/control_v11f1e_sd15_tile",
|
||||
|
||||
#################################################
|
||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||
##################################################
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-scribble-diffusers",
|
||||
"thibaud/controlnet-sd21-hed-diffusers",
|
||||
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
||||
"thibaud/controlnet-sd21-color-diffusers",
|
||||
"thibaud/controlnet-sd21-openposev2-diffusers",
|
||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||
|
||||
##############################################
|
||||
# ControlNetMediaPipeface, ControlNet v1.1
|
||||
##############################################
|
||||
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
||||
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
||||
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
||||
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
||||
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
||||
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[
|
||||
"just_resize",
|
||||
"crop_resize",
|
||||
"fill_resize",
|
||||
"just_resize_simple",
|
||||
]
|
||||
|
||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
CONTROLNET_MODE_VALUES = Literal[tuple(
|
||||
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[tuple(
|
||||
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
@@ -97,22 +62,17 @@ class ControlNetModelField(BaseModel):
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[ControlNetModelField] = Field(
|
||||
default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(
|
||||
default=1, description="The weight given to the ControlNet")
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(
|
||||
default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
||||
default="just_resize", description="The resize mode to use")
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
@@ -120,65 +80,45 @@ class ControlField(BaseModel):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < -1 or i > 2:
|
||||
raise ValueError(
|
||||
'Control weights must be within -1 to 2 range')
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
else:
|
||||
if v < -1 or v > 2:
|
||||
raise ValueError('Control weights must be within -1 to 2 range')
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
"control_model": "controlnet_model",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@title("ControlNet")
|
||||
@tags("controlnet")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
# fmt: off
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "ControlNet",
|
||||
"tags": ["controlnet", "latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
}
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(
|
||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
@@ -194,22 +134,13 @@ class ControlNetInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_processor"] = "image_processor"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to process")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Processor",
|
||||
"tags": ["image", "processor"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image):
|
||||
# superclass just passes through image without processing
|
||||
@@ -233,7 +164,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
@@ -248,405 +179,319 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class CannyImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Canny Processor")
|
||||
@tags("controlnet", "canny")
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
# fmt: off
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
# Input
|
||||
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Canny Processor",
|
||||
"tags": ["controlnet", "canny", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||
|
||||
# Input
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
high_threshold: int = InputField(
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(
|
||||
image, self.low_threshold, self.high_threshold)
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
return processed_image
|
||||
|
||||
|
||||
class HedImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("HED (softedge) Processor")
|
||||
@tags("controlnet", "hed", "softedge")
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
# fmt: off
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Softedge(HED) Processor",
|
||||
"tags": ["controlnet", "softedge", "hed", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = hed_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lineart Processor",
|
||||
"tags": ["controlnet", "lineart", "image", "processor"]
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
processed_image = lineart_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution, coarse=self.coarse)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LineartAnimeImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies line art anime processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lineart Anime Processor",
|
||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"]
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
processed_image = processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class OpenposeImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Openpose processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
# Inputs
|
||||
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Openpose Processor",
|
||||
"tags": ["controlnet", "openpose", "image", "processor"]
|
||||
},
|
||||
}
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
processed_image = hed_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,)
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MidasDepthImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
"""Applies Midas depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
# Inputs
|
||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||
# fmt: on
|
||||
@title("Lineart Processor")
|
||||
@tags("controlnet", "lineart")
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Midas (Depth) Processor",
|
||||
"tags": ["controlnet", "midas", "depth", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = lineart_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@title("Lineart Anime Processor")
|
||||
@tags("controlnet", "lineart", "anime")
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
|
||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@title("Openpose Processor")
|
||||
@tags("controlnet", "openpose", "pose")
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
|
||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||
|
||||
# Inputs
|
||||
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@title("Midas (Depth) Processor")
|
||||
@tags("controlnet", "midas", "depth")
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||
|
||||
# Inputs
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class NormalbaeImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Normal BAE Processor")
|
||||
@tags("controlnet", "normal", "bae")
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Normal BAE Processor",
|
||||
"tags": ["controlnet", "normal", "bae", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MlsdImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("MLSD Processor")
|
||||
@tags("controlnet", "mlsd")
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "MLSD Processor",
|
||||
"tags": ["controlnet", "mlsd", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image):
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution, thr_v=self.thr_v,
|
||||
thr_d=self.thr_d)
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
thr_v=self.thr_v,
|
||||
thr_d=self.thr_d,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class PidiImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("PIDI Processor")
|
||||
@tags("controlnet", "pidi")
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
safe: bool = Field(default=False, description="Whether to use safe mode")
|
||||
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "PIDI Processor",
|
||||
"tags": ["controlnet", "pidi", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
pidi_processor = PidiNetDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image, detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution, safe=self.safe,
|
||||
scribble=self.scribble)
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class ContentShuffleImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Content Shuffle Processor")
|
||||
@tags("controlnet", "contentshuffle")
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Content Shuffle Processor",
|
||||
"tags": ["controlnet", "contentshuffle", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||
|
||||
# Inputs
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image):
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f
|
||||
)
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
h=self.h,
|
||||
w=self.w,
|
||||
f=self.f,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||
class ZoeDepthImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Zoe (Depth) Processor")
|
||||
@tags("controlnet", "zoe", "depth")
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Zoe (Depth) Processor",
|
||||
"tags": ["controlnet", "zoe", "depth", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||
|
||||
def run_processor(self, image):
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained(
|
||||
"lllyasviel/Annotators")
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class MediapipeFaceProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Mediapipe Face Processor")
|
||||
@tags("controlnet", "mediapipe", "face")
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
# Inputs
|
||||
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Mediapipe Processor",
|
||||
"tags": ["controlnet", "mediapipe", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||
|
||||
# Inputs
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
|
||||
def run_processor(self, image):
|
||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||
# so convert to RGB if needed
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert('RGB')
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
|
||||
|
||||
class LeresImageProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Leres (Depth) Processor")
|
||||
@tags("controlnet", "leres", "depth")
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||
# Inputs
|
||||
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = Field(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Leres (Depth) Processor",
|
||||
"tags": ["controlnet", "leres", "depth", "image", "processor"]
|
||||
},
|
||||
}
|
||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||
|
||||
# Inputs
|
||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost,
|
||||
image,
|
||||
thr_a=self.thr_a,
|
||||
thr_b=self.thr_b,
|
||||
boost=self.boost,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution)
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
class TileResamplerProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Tile Resample Processor")
|
||||
@tags("controlnet", "tile")
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||
# Inputs
|
||||
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Tile Resample Processor",
|
||||
"tags": ["controlnet", "tile", "resample", "image", "processor"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||
|
||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||
def tile_resample(self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
def tile_resample(
|
||||
self,
|
||||
np_img: np.ndarray,
|
||||
res=512, # never used?
|
||||
down_sampling_rate=1.0,
|
||||
):
|
||||
np_img = HWC3(np_img)
|
||||
if down_sampling_rate < 1.1:
|
||||
return np_img
|
||||
@@ -658,36 +503,33 @@ class TileResamplerProcessorInvocation(
|
||||
|
||||
def run_processor(self, img):
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate
|
||||
)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
down_sampling_rate=self.down_sampling_rate,
|
||||
)
|
||||
processed_image = Image.fromarray(processed_np_image)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SegmentAnythingProcessorInvocation(
|
||||
ImageProcessorInvocation, PILInvocationConfig):
|
||||
@title("Segment Anything Processor")
|
||||
@tags("controlnet", "segmentanything")
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
# fmt: off
|
||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [
|
||||
"controlnet", "segment", "anything", "sam", "image", "processor"]}, }
|
||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||
|
||||
def run_processor(self, image):
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints")
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(np_img)
|
||||
return processed_image
|
||||
|
||||
|
||||
class SamDetectorReproducibleColors(SamDetector):
|
||||
|
||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||
# base class show_anns() method randomizes colors,
|
||||
# which seems to also lead to non-reproducible image generation
|
||||
@@ -695,19 +537,15 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
def show_anns(self, anns: List[Dict]):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||
h, w = anns[0]['segmentation'].shape
|
||||
final_img = Image.fromarray(
|
||||
np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
||||
h, w = anns[0]["segmentation"].shape
|
||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||
palette = ade_palette()
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
m = ann['segmentation']
|
||||
m = ann["segmentation"]
|
||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||
ann_color = palette[i % len(palette)]
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(
|
||||
Image.fromarray(img, mode="RGB"),
|
||||
(0, 0),
|
||||
Image.fromarray(np.uint8(m * 255)))
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
||||
@@ -5,43 +5,22 @@ from typing import Literal
|
||||
import cv2 as cv
|
||||
import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
class CvInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["cv", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
@title("OpenCV Inpaint")
|
||||
@tags("opencv", "inpaint")
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "OpenCV Inpaint",
|
||||
"tags": ["opencv", "inpaint"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
ResourceOrigin)
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
|
||||
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
from .compel import ConditioningField
|
||||
from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
from .latent import get_scheduler
|
||||
|
||||
class OldModelContext(ContextDecorator):
|
||||
model: StableDiffusionGeneratorPipeline
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __enter__(self):
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
class OldModelInfo:
|
||||
name: str
|
||||
hash: str
|
||||
context: OldModelContext
|
||||
|
||||
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
|
||||
self.name = name
|
||||
self.hash = hash
|
||||
self.context = OldModelContext(
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(BaseInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet model")
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(description="The input image")
|
||||
strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The strength of the original image"
|
||||
)
|
||||
fit: bool = Field(
|
||||
default=True,
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
# Inputs
|
||||
mask: Optional[ImageField] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(
|
||||
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||
)
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(
|
||||
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||
)
|
||||
tile_size: int = Field(
|
||||
default=32, ge=1, description="The tile infill method size (px)"
|
||||
)
|
||||
infill_method: INFILL_METHODS = Field(
|
||||
default=DEFAULT_INFILL_METHOD,
|
||||
description="The method used to infill empty regions (px)",
|
||||
)
|
||||
inpaint_width: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the inpaint region (px)",
|
||||
)
|
||||
inpaint_height: Optional[int] = Field(
|
||||
default=None,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the inpaint region (px)",
|
||||
)
|
||||
inpaint_fill: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The solid infill method color",
|
||||
)
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="The amount by which to replace masked areas with latent noise",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"title": "Inpaint"
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_conditioning(self, context, unet):
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
return (uc, c, extra_conditioning_info)
|
||||
|
||||
@contextmanager
|
||||
def load_model_old_way(self, context, scheduler):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context,)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
|
||||
|
||||
with vae_info as vae,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
with self.load_model_old_way(context, scheduler) as model:
|
||||
conditioning = self.get_conditioning(context, model.context.model.unet)
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
conditioning=conditioning,
|
||||
scheduler=scheduler,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -1,72 +1,31 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from ..models.image import (
|
||||
ImageCategory, ImageField, ResourceOrigin,
|
||||
PILInvocationConfig, ImageOutput, MaskOutput,
|
||||
)
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image and provide it as output."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to load"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Load Image",
|
||||
"tags": ["image", "load"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
@title("Show Image")
|
||||
@tags("image")
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||
|
||||
# Metadata
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Show Image",
|
||||
"tags": ["image", "show"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to show")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -82,34 +41,25 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Crop Image")
|
||||
@tags("image", "crop")
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_crop"] = "img_crop"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Crop Image",
|
||||
"tags": ["image", "crop"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_crop = Image.new(
|
||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||
)
|
||||
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
|
||||
image_crop.paste(image, (-self.x, -self.y))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
@@ -128,38 +78,31 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Paste Image")
|
||||
@tags("image", "paste")
|
||||
class ImagePasteInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_paste"] = "img_paste"
|
||||
|
||||
# Inputs
|
||||
base_image: Optional[ImageField] = Field(default=None, description="The base image")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Paste Image",
|
||||
"tags": ["image", "paste"]
|
||||
},
|
||||
}
|
||||
base_image: ImageField = InputField(description="The base image")
|
||||
image: ImageField = InputField(description="The image to paste")
|
||||
mask: Optional[ImageField] = InputField(
|
||||
default=None,
|
||||
description="The mask to use when pasting",
|
||||
)
|
||||
x: int = InputField(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
)
|
||||
mask = None
|
||||
if self.mask is not None:
|
||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||
mask = ImageOps.invert(mask.convert("L"))
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
|
||||
min_x = min(0, self.x)
|
||||
@@ -167,9 +110,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
max_x = max(base_image.width, image.width + self.x)
|
||||
max_y = max(base_image.height, image.height + self.y)
|
||||
|
||||
new_image = Image.new(
|
||||
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
|
||||
)
|
||||
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
|
||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||
|
||||
@@ -189,26 +130,19 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Mask from Alpha")
|
||||
@tags("image", "mask")
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Mask From Alpha",
|
||||
"tags": ["image", "mask", "alpha"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
@@ -224,31 +158,24 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Multiply Images")
|
||||
@tags("image", "multiply")
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
|
||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Multiply Images",
|
||||
"tags": ["image", "multiply"]
|
||||
},
|
||||
}
|
||||
image1: ImageField = InputField(description="The first image to multiply")
|
||||
image2: ImageField = InputField(description="The second image to multiply")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
@@ -275,24 +202,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Extract Image Channel")
|
||||
@tags("image", "channel")
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Channel",
|
||||
"tags": ["image", "channel"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -318,24 +238,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Convert Image Mode")
|
||||
@tags("image", "convert")
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Convert Image",
|
||||
"tags": ["image", "convert"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to convert")
|
||||
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -357,33 +270,26 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
@title("Blur Image")
|
||||
@tags("image", "blur")
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Blur Image",
|
||||
"tags": ["image", "blur"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to blur")
|
||||
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
||||
# Metadata
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius)
|
||||
if self.blur_type == "gaussian"
|
||||
else ImageFilter.BoxBlur(self.radius)
|
||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
||||
)
|
||||
blur_image = image.filter(blur)
|
||||
|
||||
@@ -423,26 +329,19 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Resize Image")
|
||||
@tags("image", "resize")
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to resize")
|
||||
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Resize Image",
|
||||
"tags": ["image", "resize"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -470,25 +369,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Scale Image")
|
||||
@tags("image", "scale")
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Scale Image",
|
||||
"tags": ["image", "scale"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to scale")
|
||||
scale_factor: float = InputField(
|
||||
default=2.0,
|
||||
gt=0,
|
||||
description="The factor by which to scale the image",
|
||||
)
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -518,31 +414,24 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@title("Lerp Image")
|
||||
@tags("image", "lerp")
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "lerp"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
image_arr = image_arr * (self.max - self.min) + self.max
|
||||
image_arr = image_arr * (self.max - self.min) + self.min
|
||||
|
||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
@@ -561,36 +450,25 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
@title("Inverse Lerp Image")
|
||||
@tags("image", "ilerp")
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image Inverse Linear Interpolation",
|
||||
"tags": ["image", "linear", "interpolation", "inverse"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = (
|
||||
numpy.minimum(
|
||||
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
|
||||
)
|
||||
* 255
|
||||
)
|
||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
||||
|
||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||
|
||||
@@ -609,35 +487,31 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
|
||||
@title("Blur NSFW Image")
|
||||
@tags("image", "nsfw")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
# fmt: off
|
||||
# Metadata
|
||||
type: Literal["img_nsfw"] = "img_nsfw"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Blur NSFW Images",
|
||||
"tags": ["image", "nsfw", "checker"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
|
||||
logger = context.services.logger
|
||||
logger.debug("Running NSFW checker")
|
||||
if SafetyChecker.has_nsfw_concept(image):
|
||||
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
caution = self._get_caution_img()
|
||||
blurry_image.paste(caution,(0,0),caution)
|
||||
blurry_image.paste(caution, (0, 0), caution)
|
||||
image = blurry_image
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
@@ -649,37 +523,34 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
def _get_caution_img(self)->Image:
|
||||
|
||||
def _get_caution_img(self) -> Image:
|
||||
import invokeai.app.assets.images as image_assets
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / 'caution.png')
|
||||
return caution.resize((caution.width // 2, caution.height //2))
|
||||
|
||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
""" Add an invisible watermark to an image """
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
# fmt: off
|
||||
|
||||
@title("Add Invisible Watermark")
|
||||
@tags("image", "watermark")
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
# Metadata
|
||||
type: Literal["img_watermark"] = "img_watermark"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Add Invisible Watermark",
|
||||
"tags": ["image", "watermark", "invisible"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -701,4 +572,332 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
|
||||
@title("Mask Edge")
|
||||
@tags("image", "mask", "inpaint")
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
type: Literal["mask_edge"] = "mask_edge"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to apply the mask to")
|
||||
edge_size: int = InputField(description="The size of the edge")
|
||||
edge_blur: int = InputField(description="The amount of blur on the edge")
|
||||
low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection")
|
||||
high_threshold: int = InputField(
|
||||
description="Second threshold for the hysteresis procedure in Canny edge detection"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||
npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
|
||||
npedge = cv2.Canny(npimg, threshold1=self.low_threshold, threshold2=self.high_threshold)
|
||||
npmask = npgradient + npedge
|
||||
npmask = cv2.dilate(npmask, numpy.ones((3, 3), numpy.uint8), iterations=int(self.edge_size / 2))
|
||||
|
||||
new_mask = Image.fromarray(npmask)
|
||||
|
||||
if self.edge_blur > 0:
|
||||
new_mask = new_mask.filter(ImageFilter.BoxBlur(self.edge_blur))
|
||||
|
||||
new_mask = ImageOps.invert(new_mask)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=new_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Combine Mask")
|
||||
@tags("image", "mask", "multiply")
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
type: Literal["mask_combine"] = "mask_combine"
|
||||
|
||||
# Inputs
|
||||
mask1: ImageField = InputField(description="The first mask to combine")
|
||||
mask2: ImageField = InputField(description="The second image to combine")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L")
|
||||
mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L")
|
||||
|
||||
combined_mask = ImageChops.multiply(mask1, mask2)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=combined_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Color Correct")
|
||||
@tags("image", "color")
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
using a mask to only color-correct certain regions of the target image.
|
||||
"""
|
||||
|
||||
type: Literal["color_correct"] = "color_correct"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to color-correct")
|
||||
reference: ImageField = InputField(description="Reference image for color-correction")
|
||||
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_init_mask = None
|
||||
if self.mask is not None:
|
||||
pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L")
|
||||
|
||||
init_image = context.services.images.get_pil_image(self.reference.image_name)
|
||||
|
||||
result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||
|
||||
# if init_image is None or init_mask is None:
|
||||
# return result
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
# pil_init_mask = (
|
||||
# init_mask.getchannel("A")
|
||||
# if init_mask.mode == "RGBA"
|
||||
# else init_mask.convert("L")
|
||||
# )
|
||||
pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = numpy.asarray(init_image.convert("RGB"), dtype=numpy.uint8)
|
||||
init_a_pixels = numpy.asarray(pil_init_image.getchannel("A"), dtype=numpy.uint8)
|
||||
init_mask_pixels = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||
|
||||
# Get numpy version of result
|
||||
np_image = numpy.asarray(result.convert("RGB"), dtype=numpy.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
if np_init_rgb_pixels_masked.size > 0:
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:, :, :] = (
|
||||
(
|
||||
(
|
||||
(np_matched_result[:, :, :].astype(numpy.float32) - gen_means[None, None, :])
|
||||
/ gen_std[None, None, :]
|
||||
)
|
||||
* init_std[None, None, :]
|
||||
+ init_means[None, None, :]
|
||||
)
|
||||
.clip(0, 255)
|
||||
.astype(numpy.uint8)
|
||||
)
|
||||
matched_result = Image.fromarray(np_matched_result, mode="RGB")
|
||||
else:
|
||||
matched_result = Image.fromarray(np_image, mode="RGB")
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if self.mask_blur_radius > 0:
|
||||
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||
nmd = cv2.erode(
|
||||
nm,
|
||||
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
|
||||
iterations=int(self.mask_blur_radius / 2),
|
||||
)
|
||||
pmd = Image.fromarray(nmd, mode="L")
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(self.mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, result.split()[-1])
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=matched_result,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Hue Adjustment")
|
||||
@tags("image", "hue", "hsl")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = numpy.array(pil_image.convert("HSV"))
|
||||
|
||||
# Convert hue from 0..360 to 0..256
|
||||
hue = int(256 * ((self.hue % 360) / 360))
|
||||
|
||||
# Increment each hue and wrap around at 255
|
||||
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Luminosity Adjustment")
|
||||
@tags("image", "luminosity", "hsl")
|
||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Luminosity (Value) of an image."""
|
||||
|
||||
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
luminosity: float = InputField(
|
||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Adjust the luminosity (value)
|
||||
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
|
||||
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Saturation Adjustment")
|
||||
@tags("image", "saturation", "hsl")
|
||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Saturation of an image."""
|
||||
|
||||
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Adjust the saturation
|
||||
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
|
||||
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
)
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
@@ -30,9 +25,7 @@ def infill_methods() -> list[str]:
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
@@ -44,9 +37,7 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
@@ -68,9 +59,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||
) -> Image.Image:
|
||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
@@ -103,9 +92,7 @@ def tile_fill_missing(
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
@@ -122,26 +109,20 @@ def tile_fill_missing(
|
||||
return si
|
||||
|
||||
|
||||
@title("Solid Color Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
color: ColorField = Field(
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
color: ColorField = InputField(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Color Infill",
|
||||
"tags": ["image", "inpaint", "color", "infill"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -166,36 +147,27 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("Tile Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
# Input
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Tile Infill",
|
||||
"tags": ["image", "inpaint", "tile", "infill"]
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
@@ -214,22 +186,15 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@title("PatchMatch Infill")
|
||||
@tags("image", "inpaint")
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Patch Match Infill",
|
||||
"tags": ["image", "inpaint", "patchmatch", "infill"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
@@ -5,89 +5,76 @@ from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
import torchvision.transforms as T
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models.base import ModelType
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||
image_resized_to_grid_as_tensor)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .image import ImageOutput
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelPatcher
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
ControlNetData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(
|
||||
default=None, description="The name of the latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||
scheduler_name, SCHEDULER_MAP['ddim']
|
||||
)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict(), context=context,
|
||||
**scheduler_info.dict(),
|
||||
context=context,
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
@@ -99,33 +86,50 @@ def get_scheduler(
|
||||
**scheduler_extra_config,
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||
if scheduler_class is DPMSolverSDEScheduler:
|
||||
scheduler_config["noise_sampler_seed"] = seed
|
||||
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
@title("Denoise Latents")
|
||||
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
|
||||
type: Literal["t2l"] = "t2l"
|
||||
type: Literal["denoise_latents"] = "denoise_latents"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
mask: Optional[ImageField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
)
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
@@ -133,39 +137,26 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
@@ -173,13 +164,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
unet,
|
||||
seed,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
@@ -190,18 +182,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None # v_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
scheduler,
|
||||
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
# flip all bits to have noise different from initial
|
||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
@@ -234,7 +225,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if unet.dtype == torch.float16 else "float32",
|
||||
)
|
||||
|
||||
def prep_control_data(
|
||||
@@ -247,7 +237,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
|
||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||
control_height_resize = latents_shape[2] * 8
|
||||
control_width_resize = latents_shape[3] * 8
|
||||
@@ -261,7 +250,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
control_list = control_input
|
||||
else:
|
||||
control_list = None
|
||||
if (control_list is None):
|
||||
if control_list is None:
|
||||
control_data = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
@@ -281,9 +270,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(
|
||||
control_image_field.image_name
|
||||
)
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@@ -316,210 +303,188 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||
num_inference_steps = steps
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(num_inference_steps, device="cpu")
|
||||
timesteps = scheduler.timesteps.to(device=device)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
# apply denoising_start
|
||||
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
||||
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps)))
|
||||
timesteps = timesteps[t_start_idx:]
|
||||
if scheduler.order == 2 and t_start_idx > 0:
|
||||
timesteps = timesteps[1:]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
# save start timestep to apply noise
|
||||
init_timestep = timesteps[:1]
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
# apply denoising_end
|
||||
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
||||
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps)))
|
||||
if scheduler.order == 2 and t_end_idx > 0:
|
||||
t_end_idx += 1
|
||||
timesteps = timesteps[:t_end_idx]
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
# calculate step count based on scheduler order
|
||||
num_inference_steps = len(timesteps)
|
||||
if scheduler.order == 2:
|
||||
num_inference_steps += num_inference_steps % 2
|
||||
num_inference_steps = num_inference_steps // 2
|
||||
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
def prep_mask_tensor(self, mask, context, lantents):
|
||||
if mask is None:
|
||||
return None
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to use as a base image")
|
||||
strength: float = Field(
|
||||
default=0.7, ge=0, le=1,
|
||||
description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Latent To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
mask_image = context.services.images.get_pil_image(mask.image_name)
|
||||
if mask_image.mode != "L":
|
||||
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR)
|
||||
return 1 - mask_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
if self.latents is not None:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
else:
|
||||
latents = torch.zeros_like(noise)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
if seed is None:
|
||||
seed = 0
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}), context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
mask = self.prep_mask_tensor(self.mask, context, latents)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
|
||||
scheduler = get_scheduler(
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||
unet_info.context.model, _lora_loader()
|
||||
), unet_info as unet:
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
if mask is not None:
|
||||
mask = mask.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
control_data = self.prep_control_data(
|
||||
model=pipeline,
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=latents.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=unet.device,
|
||||
)
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise,
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@title("Latents to Image")
|
||||
@tags("latents", "image", "vae")
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to generate an image from")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(
|
||||
default=False,
|
||||
description="Decode latents by overlapping tiles(less memory consumption)")
|
||||
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Latents To Image",
|
||||
"tags": ["latents", "image"],
|
||||
},
|
||||
}
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
metadata: CoreMetadata = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(), context=context,
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with vae_info as vae:
|
||||
@@ -586,46 +551,45 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
|
||||
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
|
||||
|
||||
@title("Resize Latents")
|
||||
@tags("latents", "resize")
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to resize")
|
||||
width: Union[int, None] = Field(default=512,
|
||||
ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: Union[int, None] = Field(default=512,
|
||||
ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Resize Latents",
|
||||
"tags": ["latents", "resize"]
|
||||
},
|
||||
}
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
width: int = InputField(
|
||||
ge=64,
|
||||
multiple_of=8,
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
height: int = InputField(
|
||||
ge=64,
|
||||
multiple_of=8,
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device=choose_torch_device()
|
||||
device = choose_torch_device()
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents.to(device), size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode, antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
latents.to(device),
|
||||
size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
@@ -635,44 +599,37 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@title("Scale Latents")
|
||||
@tags("latents", "resize")
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(
|
||||
description="The latents to scale")
|
||||
scale_factor: float = Field(
|
||||
gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||
default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(
|
||||
default=False,
|
||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Scale Latents",
|
||||
"tags": ["latents", "scale"]
|
||||
},
|
||||
}
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
|
||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device=choose_torch_device()
|
||||
device = choose_torch_device()
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
|
||||
antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
latents.to(device),
|
||||
scale_factor=self.scale_factor,
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
@@ -682,31 +639,26 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@title("Image to Latents")
|
||||
@tags("latents", "image", "vae")
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(description="The image to encode")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(
|
||||
default=False,
|
||||
description="Encode latents by overlaping tiles(less memory consumption)")
|
||||
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Image To Latents",
|
||||
"tags": ["latents", "image"]
|
||||
},
|
||||
}
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@@ -715,9 +667,10 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# )
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(), context=context,
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
@@ -744,12 +697,12 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
vae.post_quant_conv.to(orig_dtype)
|
||||
vae.decoder.conv_in.to(orig_dtype)
|
||||
vae.decoder.mid_block.to(orig_dtype)
|
||||
#else:
|
||||
# else:
|
||||
# latents = latents.float()
|
||||
|
||||
else:
|
||||
vae.to(dtype=torch.float16)
|
||||
#latents = latents.half()
|
||||
# latents = latents.half()
|
||||
|
||||
if self.tiled:
|
||||
vae.enable_tiling()
|
||||
@@ -760,9 +713,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
@@ -770,4 +721,4 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
latents = latents.to("cpu")
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@@ -2,149 +2,83 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import IntegerOutput
|
||||
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all math invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["math"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""A float output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["float_output"] = "float_output"
|
||||
param: float = Field(default=None, description="The output float")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Add Integers")
|
||||
@tags("math")
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Add",
|
||||
"tags": ["math", "add"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Subtract Integers")
|
||||
@tags("math")
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Subtract",
|
||||
"tags": ["math", "subtract"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Multiply Integers")
|
||||
@tags("math")
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Multiply",
|
||||
"tags": ["math", "multiply"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
@title("Divide Integers")
|
||||
@tags("math")
|
||||
class DivideInvocation(BaseInvocation):
|
||||
"""Divides two numbers"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Divide",
|
||||
"tags": ["math", "divide"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=int(self.a / self.b))
|
||||
|
||||
|
||||
@title("Random Integer")
|
||||
@tags("math")
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Random Integer",
|
||||
"tags": ["math", "random", "integer"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=np.random.randint(self.low, self.high))
|
||||
|
||||
@@ -1,26 +1,34 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
from ...version import __version__
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModelExcludeNull):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
|
||||
lora: LoRAModelField = Field(description="The LoRA model")
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
|
||||
|
||||
class CoreMetadata(BaseModel):
|
||||
class CoreMetadata(BaseModelExcludeNull):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
@@ -37,64 +45,49 @@ class CoreMetadata(BaseModel):
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(
|
||||
description="The ControlNets used for inference"
|
||||
)
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
vae: Optional[VAEModelField] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# Latents-to-Latents
|
||||
strength: Union[float, None] = Field(
|
||||
strength: Optional[float] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
)
|
||||
init_image: Optional[str] = Field(default=None, description="The name of the initial image")
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The positive style prompt parameter"
|
||||
)
|
||||
negative_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The negative style prompt parameter"
|
||||
)
|
||||
positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter")
|
||||
negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter")
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(
|
||||
default=None, description="The SDXL Refiner model used"
|
||||
)
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used")
|
||||
refiner_cfg_scale: Optional[float] = Field(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(
|
||||
default=None, description="The number of steps used for the refiner"
|
||||
)
|
||||
refiner_scheduler: Union[str, None] = Field(
|
||||
default=None, description="The scheduler used for the refiner"
|
||||
)
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_positive_aesthetic_store: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(
|
||||
default=None, description="The start value used for refiner denoising"
|
||||
refiner_negative_aesthetic_store: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
class ImageMetadata(BaseModelExcludeNull):
|
||||
"""An image's generation metadata"""
|
||||
|
||||
metadata: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||
)
|
||||
graph: Optional[dict] = Field(
|
||||
default=None, description="The graph that created the image"
|
||||
)
|
||||
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
||||
|
||||
|
||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
@@ -102,82 +95,86 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
||||
|
||||
metadata: CoreMetadata = Field(description="The core metadata for the image")
|
||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||
|
||||
|
||||
@title("Metadata Accumulator")
|
||||
@tags("metadata")
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||
|
||||
generation_mode: str = Field(
|
||||
generation_mode: str = InputField(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
height: int = Field(description="The height parameter")
|
||||
seed: int = Field(description="The seed used for noise generation")
|
||||
rand_device: str = Field(description="The device used for random number generation")
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(
|
||||
positive_prompt: str = InputField(description="The positive prompt parameter")
|
||||
negative_prompt: str = InputField(description="The negative prompt parameter")
|
||||
width: int = InputField(description="The width parameter")
|
||||
height: int = InputField(description="The height parameter")
|
||||
seed: int = InputField(description="The seed used for noise generation")
|
||||
rand_device: str = InputField(description="The device used for random number generation")
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: int = InputField(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(
|
||||
description="The ControlNets used for inference"
|
||||
)
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||
strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
)
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
init_image: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The positive style prompt parameter"
|
||||
positive_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The positive style prompt parameter",
|
||||
)
|
||||
negative_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The negative style prompt parameter"
|
||||
negative_style_prompt: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The negative style prompt parameter",
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(
|
||||
default=None, description="The SDXL Refiner model used"
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
refiner_cfg_scale: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(
|
||||
default=None, description="The number of steps used for the refiner"
|
||||
refiner_steps: Optional[int] = InputField(
|
||||
default=None,
|
||||
description="The number of steps used for the refiner",
|
||||
)
|
||||
refiner_scheduler: Union[str, None] = Field(
|
||||
default=None, description="The scheduler used for the refiner"
|
||||
refiner_scheduler: Optional[str] = InputField(
|
||||
default=None,
|
||||
description="The scheduler used for the refiner",
|
||||
)
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(
|
||||
default=None, description="The start value used for refiner denoising"
|
||||
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_start: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The start value used for refiner denoising",
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Metadata Accumulator",
|
||||
"tags": ["image", "metadata", "generation"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
@@ -4,17 +4,25 @@ from typing import List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(
|
||||
default=None, description="Info to load submodel"
|
||||
)
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
@@ -33,6 +41,7 @@ class ClipField(BaseModel):
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
@@ -41,19 +50,19 @@ class VaeField(BaseModel):
|
||||
class ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output"] = "model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
@@ -62,24 +71,18 @@ class LoRAModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@title("Main Model Loader")
|
||||
@tags("model")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["main_model_loader"] = "main_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Model Loader",
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@@ -154,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
@@ -180,39 +167,34 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("LoRA Loader")
|
||||
@tags("lora", "model")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["lora_loader"] = "lora_loader"
|
||||
|
||||
lora: Union[LoRAModelField, None] = Field(
|
||||
default=None, description="Lora model name"
|
||||
# Inputs
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP"
|
||||
)
|
||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||
|
||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Lora Loader",
|
||||
"tags": ["lora", "loader"],
|
||||
"type_hints": {"lora": "lora_model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
@@ -228,14 +210,10 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
|
||||
if self.unet is not None and any(
|
||||
lora.model_name == lora_name for lora in self.unet.loras
|
||||
):
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(
|
||||
lora.model_name == lora_name for lora in self.clip.loras
|
||||
):
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
@@ -267,6 +245,101 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@title("SDXL LoRA Loader")
|
||||
@tags("sdxl", "lora", "model")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = Field(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
||||
)
|
||||
clip: Optional[ClipField] = Field(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||
)
|
||||
clip2: Optional[ClipField] = Field(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
@@ -277,29 +350,23 @@ class VAEModelField(BaseModel):
|
||||
class VaeLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
||||
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
# fmt: on
|
||||
# Outputs
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("VAE Loader")
|
||||
@tags("vae", "model")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
type: Literal["vae_loader"] = "vae_loader"
|
||||
|
||||
vae_model: VAEModelField = Field(description="The VAE to load")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "VAE Loader",
|
||||
"tags": ["vae", "loader"],
|
||||
"type_hints": {"vae_model": "vae_model"},
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, validator
|
||||
import torch
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from pydantic import validator
|
||||
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -61,62 +66,53 @@ Nodes
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
# fmt: on
|
||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name),
|
||||
noise=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
@title("Noise")
|
||||
@tags("latents", "noise")
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(
|
||||
seed: int = InputField(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use",
|
||||
description=FieldDescriptions.seed,
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
width: int = Field(
|
||||
width: int = InputField(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The width of the resulting noise",
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
height: int = Field(
|
||||
height: int = InputField(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
gt=0,
|
||||
description="The height of the resulting noise",
|
||||
description=FieldDescriptions.height,
|
||||
)
|
||||
use_cpu: bool = Field(
|
||||
use_cpu: bool = InputField(
|
||||
default=True,
|
||||
description="Use CPU for noise generation (for reproducible results across platforms)",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Noise",
|
||||
"tags": ["latents", "noise"],
|
||||
},
|
||||
}
|
||||
|
||||
@validator("seed", pre=True)
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
@@ -132,4 +128,4 @@ class NoiseInvocation(BaseInvocation):
|
||||
)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
||||
|
||||
512
invokeai/app/invocations/onnx.py
Normal file
512
invokeai/app/invocations/onnx.py
Normal file
@@ -0,0 +1,512 @@
|
||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||
from .model import ClipField, ModelInfo, UNetField, VaeField
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||
|
||||
|
||||
@title("ONNX Prompt (Raw)")
|
||||
@tags("onnx", "prompt")
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
||||
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.clip.loras
|
||||
]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
if loras or ti_list:
|
||||
text_encoder.release_session()
|
||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
||||
orig_tokenizer, text_encoder, ti_list
|
||||
) as (tokenizer, ti_manager):
|
||||
text_encoder.create_session()
|
||||
|
||||
# copy from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
||||
text_inputs = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
"""
|
||||
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
"""
|
||||
|
||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
@title("ONNX Text to Latents")
|
||||
@tags("latents", "inference", "txt2img", "onnx")
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
||||
|
||||
# Inputs
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
noise: LatentsField = InputField(
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
)
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
ui_type=UIType.Control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
if isinstance(c, torch.Tensor):
|
||||
c = c.cpu().numpy()
|
||||
if isinstance(uc, torch.Tensor):
|
||||
uc = uc.cpu().numpy()
|
||||
device = torch.device(choose_torch_device())
|
||||
prompt_embeds = np.concatenate([uc, c])
|
||||
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
if isinstance(latents, torch.Tensor):
|
||||
latents = latents.cpu().numpy()
|
||||
|
||||
# TODO: better execution device handling
|
||||
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
do_classifier_free_guidance = True
|
||||
# latents_dtype = prompt_embeds.dtype
|
||||
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
# if latents.shape != latents_shape:
|
||||
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=0, # TODO: refactor this node
|
||||
)
|
||||
|
||||
def torch2numpy(latent: torch.Tensor):
|
||||
return latent.cpu().numpy()
|
||||
|
||||
def numpy2torch(latent, device):
|
||||
return torch.from_numpy(latent).to(device)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
scheduler.set_timesteps(self.steps)
|
||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
|
||||
with unet_info as unet, ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.unet.loras
|
||||
]
|
||||
|
||||
if loras:
|
||||
unet.release_session()
|
||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||
# TODO:
|
||||
_, _, h, w = latents.shape
|
||||
unet.create_session(h, w)
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
for i in tqdm(range(len(scheduler.timesteps))):
|
||||
t = scheduler.timesteps[i]
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
||||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = scheduler.step(
|
||||
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
||||
)
|
||||
latents = torch2numpy(scheduler_output.prev_sample)
|
||||
|
||||
state = PipelineIntermediateState(
|
||||
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
||||
)
|
||||
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
||||
|
||||
# call the callback, if provided
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
||||
|
||||
|
||||
# Latent to image
|
||||
@title("ONNX Latents to Image")
|
||||
@tags("latents", "image", "vae", "onnx")
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.denoised_latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
metadata: Optional[CoreMetadata] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with vae_info as vae:
|
||||
vae.create_session()
|
||||
|
||||
# copied from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
||||
|
||||
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@title("ONNX Model Loader")
|
||||
@tags("onnx", "model")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
||||
|
||||
# Inputs
|
||||
model: OnnxModelField = InputField(
|
||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -1,61 +1,68 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
from typing import Literal, Optional
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import PIL.Image
|
||||
from easing_functions import (
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseInOut,
|
||||
CubicEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseOut,
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
QuadEaseIn,
|
||||
QuadEaseInOut,
|
||||
QuadEaseOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseOut,
|
||||
SineEaseIn,
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||
|
||||
|
||||
@title("Float Range")
|
||||
@tags("math", "range")
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Linear Range (Float)",
|
||||
"tags": ["math", "float", "linear", "range"]
|
||||
},
|
||||
}
|
||||
start: float = InputField(default=5, description="The first value of the range")
|
||||
stop: float = InputField(default=10, description="The last value of the range")
|
||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
@@ -92,43 +99,32 @@ EASING_FUNCTIONS_MAP = {
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@title("Step Param Easing")
|
||||
@tags("step", "easing")
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Param Easing By Step",
|
||||
"tags": ["param", "step", "easing"]
|
||||
},
|
||||
}
|
||||
|
||||
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
@@ -170,12 +166,13 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||
easing_function = easing_class(
|
||||
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
||||
)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
@@ -214,9 +211,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
@@ -240,13 +235,11 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .math import FloatOutput, IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["param", "integer"],
|
||||
"title": "Integer Parameter"
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
|
||||
class ParamFloatInvocation(BaseInvocation):
|
||||
"""A float parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_float"] = "param_float"
|
||||
param: float = Field(default=0.0, description="The float value")
|
||||
#fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["param", "float"],
|
||||
"title": "Float Parameter"
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(param=self.param)
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""A string output"""
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = Field(default=None, description="The output string")
|
||||
|
||||
|
||||
class ParamStringInvocation(BaseInvocation):
|
||||
"""A string parameter"""
|
||||
type: Literal['param_string'] = 'param_string'
|
||||
text: str = Field(default='', description='The string value')
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["param", "string"],
|
||||
"title": "String Parameter"
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
494
invokeai/app/invocations/primitives.py
Normal file
494
invokeai/app/invocations/primitives.py
Normal file
@@ -0,0 +1,494 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Tuple, Union
|
||||
from anyio import Condition
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
|
||||
"""
|
||||
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
|
||||
- primitive nodes
|
||||
- primitive outputs
|
||||
- primitive collection outputs
|
||||
"""
|
||||
|
||||
# region Boolean
|
||||
|
||||
|
||||
class BooleanOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single boolean"""
|
||||
|
||||
type: Literal["boolean_output"] = "boolean_output"
|
||||
a: bool = OutputField(description="The output boolean")
|
||||
|
||||
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of booleans"""
|
||||
|
||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[bool] = OutputField(
|
||||
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Boolean Primitive")
|
||||
@tags("primitives", "boolean")
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
|
||||
type: Literal["boolean"] = "boolean"
|
||||
|
||||
# Inputs
|
||||
a: bool = InputField(default=False, description="The boolean value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||
return BooleanOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Boolean Primitive Collection")
|
||||
@tags("primitives", "boolean", "collection")
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
type: Literal["boolean_collection"] = "boolean_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[bool] = InputField(
|
||||
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Integer
|
||||
|
||||
|
||||
class IntegerOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single integer"""
|
||||
|
||||
type: Literal["integer_output"] = "integer_output"
|
||||
a: int = OutputField(description="The output integer")
|
||||
|
||||
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of integers"""
|
||||
|
||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = OutputField(
|
||||
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Integer Primitive")
|
||||
@tags("primitives", "integer")
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
|
||||
type: Literal["integer"] = "integer"
|
||||
|
||||
# Inputs
|
||||
a: int = InputField(default=0, description="The integer value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(a=self.a)
|
||||
|
||||
|
||||
@title("Integer Primitive Collection")
|
||||
@tags("primitives", "integer", "collection")
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
type: Literal["integer_collection"] = "integer_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[int] = InputField(
|
||||
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Float
|
||||
|
||||
|
||||
class FloatOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single float"""
|
||||
|
||||
type: Literal["float_output"] = "float_output"
|
||||
a: float = OutputField(description="The output float")
|
||||
|
||||
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
type: Literal["float_collection_output"] = "float_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[float] = OutputField(
|
||||
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Float Primitive")
|
||||
@tags("primitives", "float")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
type: Literal["float"] = "float"
|
||||
|
||||
# Inputs
|
||||
param: float = InputField(default=0.0, description="The float value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(a=self.param)
|
||||
|
||||
|
||||
@title("Float Primitive Collection")
|
||||
@tags("primitives", "float", "collection")
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
type: Literal["float_collection"] = "float_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[float] = InputField(
|
||||
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region String
|
||||
|
||||
|
||||
class StringOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single string"""
|
||||
|
||||
type: Literal["string_output"] = "string_output"
|
||||
text: str = OutputField(description="The output string")
|
||||
|
||||
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
type: Literal["string_collection_output"] = "string_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[str] = OutputField(
|
||||
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
|
||||
@title("String Primitive")
|
||||
@tags("primitives", "string")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
type: Literal["string"] = "string"
|
||||
|
||||
# Inputs
|
||||
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(text=self.text)
|
||||
|
||||
|
||||
@title("String Primitive Collection")
|
||||
@tags("primitives", "string", "collection")
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
type: Literal["string_collection"] = "string_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[str] = InputField(
|
||||
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = OutputField(description="The output image")
|
||||
width: int = OutputField(description="The width of the image in pixels")
|
||||
height: int = OutputField(description="The height of the image in pixels")
|
||||
|
||||
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
type: Literal["image_collection_output"] = "image_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ImageField] = OutputField(
|
||||
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Image Primitive")
|
||||
@tags("primitives", "image")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
# Metadata
|
||||
type: Literal["image"] = "image"
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
@title("Image Primitive Collection")
|
||||
@tags("primitives", "image", "collection")
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
|
||||
type: Literal["image_collection"] = "image_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ImageField] = InputField(
|
||||
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latents
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single latents tensor"""
|
||||
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
latents: LatentsField = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
|
||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of latents tensors"""
|
||||
|
||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
default_factory=list,
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
|
||||
|
||||
@title("Latents Primitive")
|
||||
@tags("primitives", "latents")
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
|
||||
type: Literal["latents"] = "latents"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@title("Latents Primitive Collection")
|
||||
@tags("primitives", "latents", "collection")
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
|
||||
type: Literal["latents_collection"] = "latents_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[LatentsField] = InputField(
|
||||
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Color
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ColorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single color"""
|
||||
|
||||
type: Literal["color_output"] = "color_output"
|
||||
color: ColorField = OutputField(description="The output color")
|
||||
|
||||
|
||||
class ColorCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of colors"""
|
||||
|
||||
type: Literal["color_collection_output"] = "color_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ColorField] = OutputField(
|
||||
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
|
||||
)
|
||||
|
||||
|
||||
@title("Color Primitive")
|
||||
@tags("primitives", "color")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
type: Literal["color"] = "color"
|
||||
|
||||
# Inputs
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ColorOutput:
|
||||
return ColorOutput(color=self.color)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Conditioning
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
type: Literal["conditioning_output"] = "conditioning_output"
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
|
||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of conditioning tensors"""
|
||||
|
||||
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
|
||||
|
||||
# Outputs
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
default_factory=list,
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
|
||||
@title("Conditioning Primitive")
|
||||
@tags("primitives", "conditioning")
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
type: Literal["conditioning"] = "conditioning"
|
||||
|
||||
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@title("Conditioning Primitive Collection")
|
||||
@tags("primitives", "conditioning", "collection")
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
||||
type: Literal["conditioning_collection"] = "conditioning_collection"
|
||||
|
||||
# Inputs
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
return ConditioningCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -1,62 +1,28 @@
|
||||
from os.path import exists
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field, validator
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from pydantic import validator
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
|
||||
class PromptOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a prompt"""
|
||||
#fmt: off
|
||||
type: Literal["prompt"] = "prompt"
|
||||
|
||||
prompt: str = Field(default=None, description="The output prompt")
|
||||
#fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'prompt',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class PromptCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a collection of prompts"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["prompt_collection_output"] = "prompt_collection_output"
|
||||
|
||||
prompt_collection: list[str] = Field(description="The output prompt collection")
|
||||
count: int = Field(description="The size of the prompt collection")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "prompt_collection", "count"]}
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
|
||||
|
||||
|
||||
@title("Dynamic Prompt")
|
||||
@tags("prompt", "collection")
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = Field(
|
||||
default=False, description="Whether to use the combinatorial generator"
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Dynamic Prompt",
|
||||
"tags": ["prompt", "dynamic"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
if self.combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
||||
@@ -64,29 +30,26 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
|
||||
@title("Prompts from File")
|
||||
@tags("prompt", "file")
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
'''Loads prompts from a text file'''
|
||||
# fmt: off
|
||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
type: Literal["prompt_from_file"] = "prompt_from_file"
|
||||
|
||||
# Inputs
|
||||
file_path: str = Field(description="Path to prompt text file")
|
||||
pre_prompt: Optional[str] = Field(description="String to prepend to each prompt")
|
||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
#fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompts From File",
|
||||
"tags": ["prompt", "file"]
|
||||
},
|
||||
}
|
||||
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
|
||||
pre_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||
)
|
||||
post_prompt: Optional[str] = InputField(
|
||||
default=None, description="String to append to each prompt", ui_component=UIComponent.Textarea
|
||||
)
|
||||
start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from")
|
||||
max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||
|
||||
@validator("file_path")
|
||||
def file_path_exists(cls, v):
|
||||
@@ -94,7 +57,14 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
raise ValueError(FileNotFoundError)
|
||||
return v
|
||||
|
||||
def promptsFromFile(self, file_path: str, pre_prompt: str, post_prompt: str, start_line: int, max_prompts: int):
|
||||
def promptsFromFile(
|
||||
self,
|
||||
file_path: str,
|
||||
pre_prompt: Union[str, None],
|
||||
post_prompt: Union[str, None],
|
||||
start_line: int,
|
||||
max_prompts: int,
|
||||
):
|
||||
prompts = []
|
||||
start_line -= 1
|
||||
end_line = start_line + max_prompts
|
||||
@@ -103,11 +73,13 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
with open(file_path) as f:
|
||||
for i, line in enumerate(f):
|
||||
if i >= start_line and i < end_line:
|
||||
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
|
||||
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
|
||||
if i >= end_line:
|
||||
break
|
||||
return prompts
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
|
||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
||||
)
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
@@ -1,59 +1,55 @@
|
||||
import torch
|
||||
import inspect
|
||||
from tqdm import tqdm
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field, validator
|
||||
from typing import Literal
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
|
||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
||||
from .compel import ConditioningField
|
||||
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
||||
|
||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
# fmt: off
|
||||
|
||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
# fmt: on
|
||||
#fmt: on
|
||||
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@title("SDXL Main Model Loader")
|
||||
@tags("model", "sdxl")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Model Loader",
|
||||
"tags": ["model", "loader", "sdxl"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@@ -125,23 +121,22 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@title("SDXL Refiner Model Loader")
|
||||
@tags("model", "sdxl", "refiner")
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||
|
||||
model: MainModelField = Field(description="The model to load")
|
||||
# Inputs
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Model Loader",
|
||||
"tags": ["model", "loader", "sdxl_refiner"],
|
||||
"type_hints": {"model": "refiner_model"},
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
@@ -196,514 +191,3 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Text to image
|
||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l_sdxl"] = "t2l_sdxl"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Text To Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
) -> None:
|
||||
stable_diffusion_xl_step_callback(
|
||||
context=context,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
sample=sample,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
num_inference_steps = self.steps
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(), context=context
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
)
|
||||
|
||||
num_warmup_steps = len(timesteps) - self.steps * scheduler.order
|
||||
|
||||
# apply denoising_end
|
||||
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||
|
||||
if not context.services.configuration.sequential_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
#del noise_pred_uncond
|
||||
#del noise_pred_text
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
|
||||
noise_pred_uncond = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||
noise_pred_text = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
#del noise_pred_text
|
||||
#del noise_pred_uncond
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
#del noise_pred
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["l2l_sdxl"] = "l2l_sdxl"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
latents: Optional[LatentsField] = Field(description="Initial latents")
|
||||
|
||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "SDXL Latents to Latents",
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
) -> None:
|
||||
stable_diffusion_xl_step_callback(
|
||||
context=context,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
sample=sample,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
prompt_embeds = positive_cond_data.conditionings[0].embeds
|
||||
pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
|
||||
add_time_ids = positive_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
|
||||
negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
|
||||
add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
# apply denoising_start
|
||||
num_inference_steps = self.steps
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order:]
|
||||
num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# apply noise(if provided)
|
||||
if self.noise is not None and timesteps.shape[0] > 0:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||
del noise
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with unet_info as unet:
|
||||
|
||||
# apply scheduler extra args
|
||||
extra_step_kwargs = dict()
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
||||
|
||||
# apply denoising_end
|
||||
skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
|
||||
num_inference_steps = num_inference_steps - skipped_final_steps
|
||||
timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]
|
||||
|
||||
if not context.services.configuration.sequential_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
#del noise_pred_uncond
|
||||
#del noise_pred_text
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
else:
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
with tqdm(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
# predict the noise residual
|
||||
|
||||
added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||
noise_pred_uncond = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
||||
noise_pred_text = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
#del noise_pred_text
|
||||
#del noise_pred_uncond
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
#del noise_pred
|
||||
#import gc
|
||||
#gc.collect()
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||
#if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
|
||||
|
||||
#################
|
||||
|
||||
latents = latents.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
@@ -6,13 +6,12 @@ import cv2 as cv
|
||||
import numpy as np
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from pydantic import Field
|
||||
from realesrgan import RealESRGANer
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@@ -24,22 +23,16 @@ ESRGAN_MODELS = Literal[
|
||||
]
|
||||
|
||||
|
||||
@title("Upscale (RealESRGAN)")
|
||||
@tags("esrgan", "upscale")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
type: Literal["esrgan"] = "esrgan"
|
||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||
model_name: ESRGAN_MODELS = Field(
|
||||
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Upscale (RealESRGAN)",
|
||||
"tags": ["image", "upscale", "realesrgan"]
|
||||
},
|
||||
}
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -108,9 +101,7 @@ class ESRGANInvocation(BaseInvocation):
|
||||
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||
|
||||
# back to PIL
|
||||
pil_image = Image.fromarray(
|
||||
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
|
||||
).convert("RGBA")
|
||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,30 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
@@ -34,47 +12,6 @@ class ProgressImage(BaseModel):
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
@@ -132,5 +69,3 @@ class InvalidImageCategoryException(ValueError):
|
||||
|
||||
def __init__(self, message="Invalid image category."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
@@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
@@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE board_id = ? AND image_name = ?;
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(board_id, image_name),
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@@ -207,9 +205,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
try:
|
||||
|
||||
@@ -31,7 +31,6 @@ class BoardImagesServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
@@ -93,18 +92,15 @@ class BoardImagesService(BoardImagesServiceABC):
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
||||
self._services.board_image_records.remove_image_from_board(image_name)
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> list[str]:
|
||||
return self._services.board_image_records.get_all_board_image_names_for_board(
|
||||
board_id
|
||||
)
|
||||
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
@@ -114,9 +110,7 @@ class BoardImagesService(BoardImagesServiceABC):
|
||||
return board_id
|
||||
|
||||
|
||||
def board_record_to_dto(
|
||||
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
|
||||
) -> BoardDTO:
|
||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.dict(exclude={"cover_image_name"}),
|
||||
|
||||
@@ -15,9 +15,7 @@ from pydantic import BaseModel, Field, Extra
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
board_name: Optional[str] = Field(description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's new cover image."
|
||||
)
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
@@ -292,9 +290,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](
|
||||
items=boards, offset=offset, limit=limit, total=count
|
||||
)
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
|
||||
@@ -108,16 +108,12 @@ class BoardService(BoardServiceABC):
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self._services.board_records.get(board_id)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
@@ -126,60 +122,44 @@ class BoardService(BoardServiceABC):
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.update(board_id, changes)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self._services.board_records.delete(board_id)
|
||||
|
||||
def get_many(
|
||||
self, offset: int = 0, limit: int = 10
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self._services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](
|
||||
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||
)
|
||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self._services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
||||
return board_dtos
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
'''Invokeai configuration system.
|
||||
"""Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
@@ -24,11 +24,10 @@ InvokeAI:
|
||||
sequential_guidance: false
|
||||
precision: float16
|
||||
max_cache_size: 6
|
||||
max_vram_cache_size: 2.7
|
||||
max_vram_cache_size: 0.5
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
restore: true
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
internet_available: true
|
||||
@@ -158,76 +157,85 @@ two configs are kept in separate sections of the config file:
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
'''
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import pydoc
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
MODEL_CORE = Path('models/core')
|
||||
DB_FILE = Path('invokeai.db')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
"""
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
"""
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt = parser.parse_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
value = getattr(opt, name)
|
||||
if isinstance(value, ListConfig):
|
||||
value = list(value)
|
||||
elif isinstance(value, DictConfig):
|
||||
value = dict(value)
|
||||
setattr(self, name, value)
|
||||
|
||||
def to_yaml(self)->str:
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)['type'])[0]
|
||||
field_dict = dict({type:dict()})
|
||||
for name,field in self.__fields__.items():
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict = dict({type: dict()})
|
||||
for name, field in self.__fields__.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self,name)
|
||||
value = getattr(self, name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if 'type' in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
|
||||
initconf = cls.initconf.get(settings_stanza) \
|
||||
if cls.initconf and settings_stanza in cls.initconf \
|
||||
else OmegaConf.create()
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key,value in os.environ.items():
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
@@ -237,8 +245,8 @@ class InvokeAISettings(BaseSettings):
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category","Uncategorized")
|
||||
env_name = env_prefix + '_' + name
|
||||
category = field.field_info.extra.get("category", "Uncategorized")
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
@@ -248,15 +256,15 @@ class InvokeAISettings(BaseSettings):
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'Uncategorized'
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
@@ -269,24 +277,37 @@ class InvokeAISettings(BaseSettings):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
def _excluded(self) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ['type','initconf']
|
||||
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self)->List[str]:
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root', 'nsfw_checker']
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
env_file_encoding = "utf-8"
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
@@ -315,10 +336,10 @@ class InvokeAISettings(BaseSettings):
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs='*',
|
||||
nargs="*",
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
@@ -327,31 +348,35 @@ class InvokeAISettings(BaseSettings):
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
|
||||
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
"""
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
"""
|
||||
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
|
||||
#fmt: off
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
@@ -364,21 +389,17 @@ setting environment variables INVOKEAI_<setting>.
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
|
||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
|
||||
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
@@ -390,8 +411,7 @@ setting environment variables INVOKEAI_<setting>.
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
||||
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
@@ -399,16 +419,19 @@ setting environment variables INVOKEAI_<setting>.
|
||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||
'''
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
|
||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
'''
|
||||
"""
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
@@ -425,135 +448,141 @@ setting environment variables INVOKEAI_<setting>.
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
|
||||
|
||||
@classmethod
|
||||
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
'''
|
||||
if cls.singleton_config is None \
|
||||
or type(cls.singleton_config)!=cls \
|
||||
or (kwargs and cls.singleton_init != kwargs):
|
||||
"""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
or type(cls.singleton_config) != cls
|
||||
or (kwargs and cls.singleton_init != kwargs)
|
||||
):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
def root_path(self) -> Path:
|
||||
"""
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
"""
|
||||
if self.root:
|
||||
return Path(self.root).expanduser().absolute()
|
||||
root = Path(self.root).expanduser().absolute()
|
||||
else:
|
||||
return self.find_root()
|
||||
root = self.find_root().expanduser().absolute()
|
||||
self.root = root # insulate ourselves from relative paths that may change
|
||||
return root
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
def root_dir(self) -> Path:
|
||||
"""
|
||||
Alias for above.
|
||||
'''
|
||||
"""
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
def _resolve(self, partial_path: Path) -> Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def init_file_path(self)->Path:
|
||||
'''
|
||||
def init_file_path(self) -> Path:
|
||||
"""
|
||||
Path to invokeai.yaml
|
||||
'''
|
||||
"""
|
||||
return self._resolve(INIT_FILE)
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
def output_path(self) -> Path:
|
||||
"""
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self)->Path:
|
||||
'''
|
||||
def db_path(self) -> Path:
|
||||
"""
|
||||
Path to the invokeai.db file.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.db_dir) / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
def model_conf_path(self) -> Path:
|
||||
"""
|
||||
Path to models configuration file.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self)->Path:
|
||||
'''
|
||||
def legacy_conf_path(self) -> Path:
|
||||
"""
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self)->Path:
|
||||
'''
|
||||
def models_path(self) -> Path:
|
||||
"""
|
||||
Path to the models directory
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
def autoconvert_path(self) -> Path:
|
||||
"""
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
"""
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
def full_precision(self) -> bool:
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision=='float32'
|
||||
return self.precision == "float32"
|
||||
|
||||
@property
|
||||
def disable_xformers(self)->bool:
|
||||
def disable_xformers(self) -> bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self)->bool:
|
||||
def try_patchmatch(self) -> bool:
|
||||
"""Return true if patchmatch true"""
|
||||
return self.patchmatch
|
||||
|
||||
@property
|
||||
def nsfw_checker(self)->bool:
|
||||
""" NSFW node is always active and disabled from Web UIe"""
|
||||
def nsfw_checker(self) -> bool:
|
||||
"""NSFW node is always active and disabled from Web UIe"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def invisible_watermark(self)->bool:
|
||||
""" invisible watermark node is always active and disabled from Web UIe"""
|
||||
def invisible_watermark(self) -> bool:
|
||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
def find_root() -> Path:
|
||||
"""
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
"""
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
'''
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||
'''
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
'''
|
||||
"""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
@@ -1,63 +1,86 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
|
||||
from ..invocations.latent import LatentsToImageInvocation, DenoiseLatentsInvocation
|
||||
from ..invocations.image import ImageNSFWBlurInvocation
|
||||
from ..invocations.noise import NoiseInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from ..invocations.primitives import IntegerInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
name="t2i",
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'seed': ParamIntInvocation(id='seed', a=-1),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': CompelInvocation(id='4'),
|
||||
'5': CompelInvocation(id='5'),
|
||||
'6': TextToLatentsInvocation(id='6'),
|
||||
'7': LatentsToImageInvocation(id='7'),
|
||||
'8': ImageNSFWBlurInvocation(id='8'),
|
||||
"width": IntegerInvocation(id="width", a=512),
|
||||
"height": IntegerInvocation(id="height", a=512),
|
||||
"seed": IntegerInvocation(id="seed", a=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
|
||||
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
|
||||
Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')),
|
||||
]
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="a"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="a"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="a"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
|
||||
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height'),
|
||||
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='8', field='image', alias='image')
|
||||
])
|
||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||
)
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
|
||||
@@ -35,6 +35,7 @@ class EventServiceBase:
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
order: int,
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
@@ -44,10 +45,9 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict()
|
||||
if progress_image is not None
|
||||
else None,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
order=order,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
@@ -90,9 +90,7 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||
) -> None:
|
||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_started",
|
||||
|
||||
@@ -3,16 +3,7 @@
|
||||
import copy
|
||||
import itertools
|
||||
import uuid
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
@@ -22,12 +13,17 @@ from ..invocations import *
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
node_id: str = Field(description="The id of the node for this edge connection")
|
||||
field: str = Field(description="The field for this connection")
|
||||
@@ -61,6 +57,7 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
||||
node_input_field = node_inputs.get(field) or None
|
||||
return node_input_field
|
||||
|
||||
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
@@ -71,6 +68,7 @@ def is_union_subtype(t1, t2):
|
||||
# t1 is a Union, check that all of its types are in t2_args
|
||||
return all(arg in t2_args for arg in t1_args)
|
||||
|
||||
|
||||
def is_list_or_contains_list(t):
|
||||
t_args = get_args(t)
|
||||
|
||||
@@ -154,15 +152,17 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
"required": [
|
||||
"type",
|
||||
"image",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
@@ -178,27 +178,21 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
type: Literal["iterate_output"] = "iterate_output"
|
||||
|
||||
item: Any = Field(description="The item being iterated over")
|
||||
item: Any = OutputField(
|
||||
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'item',
|
||||
]
|
||||
}
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(
|
||||
description="The list of items to iterate over", default_factory=list
|
||||
)
|
||||
index: int = Field(
|
||||
description="The index, will be provided on executed iterators", default=0
|
||||
collection: list[Any] = InputField(
|
||||
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
||||
)
|
||||
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||
"""Produces the outputs as values"""
|
||||
@@ -208,28 +202,24 @@ class IterateInvocation(BaseInvocation):
|
||||
class CollectInvocationOutput(BaseInvocationOutput):
|
||||
type: Literal["collect_output"] = "collect_output"
|
||||
|
||||
collection: list[Any] = Field(description="The collection of input items")
|
||||
collection: list[Any] = OutputField(
|
||||
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
||||
)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'collection',
|
||||
]
|
||||
}
|
||||
|
||||
class CollectInvocation(BaseInvocation):
|
||||
"""Collects values into a collection"""
|
||||
|
||||
type: Literal["collect"] = "collect"
|
||||
|
||||
item: Any = Field(
|
||||
item: Any = InputField(
|
||||
description="The item to collect (all inputs must be of the same type)",
|
||||
default=None,
|
||||
ui_type=UIType.CollectionItem,
|
||||
title="Collection Item",
|
||||
input=Input.Connection,
|
||||
)
|
||||
collection: list[Any] = Field(
|
||||
description="The collection, will be provided on execution",
|
||||
default_factory=list,
|
||||
collection: list[Any] = InputField(
|
||||
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||
@@ -269,9 +259,7 @@ class Graph(BaseModel):
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = (
|
||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
)
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||
|
||||
@@ -333,9 +321,7 @@ class Graph(BaseModel):
|
||||
return False
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set(
|
||||
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||
)
|
||||
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
|
||||
@@ -361,22 +347,14 @@ class Graph(BaseModel):
|
||||
# Validate all iterators
|
||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
||||
if not all(
|
||||
(
|
||||
self._is_iterator_connection_valid(n.id)
|
||||
for n in self.nodes.values()
|
||||
if isinstance(n, IterateInvocation)
|
||||
)
|
||||
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate all collectors
|
||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
||||
if not all(
|
||||
(
|
||||
self._is_collector_connection_valid(n.id)
|
||||
for n in self.nodes.values()
|
||||
if isinstance(n, CollectInvocation)
|
||||
)
|
||||
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
|
||||
):
|
||||
return False
|
||||
|
||||
@@ -395,48 +373,51 @@ class Graph(BaseModel):
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
raise InvalidEdgeError(
|
||||
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
|
||||
)
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
raise InvalidEdgeError(
|
||||
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
|
||||
)
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
|
||||
raise InvalidEdgeError(
|
||||
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
|
||||
raise InvalidEdgeError(
|
||||
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||
raise InvalidEdgeError(
|
||||
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
|
||||
raise InvalidEdgeError(
|
||||
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||
raise InvalidEdgeError(
|
||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@@ -465,17 +446,13 @@ class Graph(BaseModel):
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) != type(new_node):
|
||||
raise TypeError(
|
||||
f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
|
||||
)
|
||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||
if new_node.id != node.id and self.has_node(new_path):
|
||||
raise NodeAlreadyInGraphError(
|
||||
"Node with id {new_node.id} already exists in graph"
|
||||
)
|
||||
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
||||
|
||||
# Set the new node in the graph
|
||||
graph.nodes[new_node.id] = new_node
|
||||
@@ -497,9 +474,7 @@ class Graph(BaseModel):
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.destination.field
|
||||
)
|
||||
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -512,16 +487,12 @@ class Graph(BaseModel):
|
||||
)
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=new_graph_node_path, field=edge.source.field
|
||||
),
|
||||
destination=edge.destination
|
||||
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
||||
destination=edge.destination,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(
|
||||
self, node_path: str, field: Optional[str] = None
|
||||
) -> list[Edge]:
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
@@ -538,7 +509,7 @@ class Graph(BaseModel):
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
@@ -550,32 +521,20 @@ class Graph(BaseModel):
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||
)
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
|
||||
node_id = (
|
||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
)
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = (
|
||||
node.id
|
||||
if prefix is None or prefix == ""
|
||||
else self._get_node_path(node.id, prefix=prefix)
|
||||
)
|
||||
graph_edges = graph._get_input_edges_and_graphs(
|
||||
node_path[(len(node_id) + 1) :], prefix=graph_path
|
||||
)
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
|
||||
def _get_output_edges(
|
||||
self, node_path: str, field: str
|
||||
) -> list[Edge]:
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
@@ -592,7 +551,7 @@ class Graph(BaseModel):
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
)
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
@@ -604,25 +563,15 @@ class Graph(BaseModel):
|
||||
edges = list()
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend(
|
||||
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||
)
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
|
||||
node_id = (
|
||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
)
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = (
|
||||
node.id
|
||||
if prefix is None or prefix == ""
|
||||
else self._get_node_path(node.id, prefix=prefix)
|
||||
)
|
||||
graph_edges = graph._get_output_edges_and_graphs(
|
||||
node_path[(len(node_id) + 1) :], prefix=graph_path
|
||||
)
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
@@ -646,12 +595,8 @@ class Graph(BaseModel):
|
||||
return False
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field = get_output_field(
|
||||
self.get_node(inputs[0].node_id), inputs[0].field
|
||||
)
|
||||
output_fields = list(
|
||||
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
)
|
||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Input type must be a list
|
||||
if get_origin(input_field) != list:
|
||||
@@ -659,12 +604,7 @@ class Graph(BaseModel):
|
||||
|
||||
# Validate that all outputs match the input type
|
||||
input_field_item_type = get_args(input_field)[0]
|
||||
if not all(
|
||||
(
|
||||
are_connection_types_compatible(input_field_item_type, f)
|
||||
for f in output_fields
|
||||
)
|
||||
):
|
||||
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -684,35 +624,21 @@ class Graph(BaseModel):
|
||||
outputs.append(new_output)
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_fields = list(
|
||||
[get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
||||
)
|
||||
output_fields = list(
|
||||
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
)
|
||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||
|
||||
# Validate that all inputs are derived from or match a single type
|
||||
input_field_types = set(
|
||||
[
|
||||
t
|
||||
for input_field in input_fields
|
||||
for t in (
|
||||
[input_field]
|
||||
if get_origin(input_field) == None
|
||||
else get_args(input_field)
|
||||
)
|
||||
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
|
||||
if t != NoneType
|
||||
]
|
||||
) # Get unique types
|
||||
type_tree = nx.DiGraph()
|
||||
type_tree.add_nodes_from(input_field_types)
|
||||
type_tree.add_edges_from(
|
||||
[
|
||||
e
|
||||
for e in itertools.permutations(input_field_types, 2)
|
||||
if issubclass(e[1], e[0])
|
||||
]
|
||||
)
|
||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||
return False # There is more than one root type
|
||||
@@ -729,9 +655,7 @@ class Graph(BaseModel):
|
||||
return False
|
||||
|
||||
# Verify that all outputs match the input type (are a base class or the same class)
|
||||
if not all(
|
||||
(issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
|
||||
):
|
||||
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -751,9 +675,7 @@ class Graph(BaseModel):
|
||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||
return g
|
||||
|
||||
def nx_graph_flat(
|
||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||
) -> nx.DiGraph:
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||
g = nx_graph or nx.DiGraph()
|
||||
|
||||
@@ -762,26 +684,18 @@ class Graph(BaseModel):
|
||||
[
|
||||
self._get_node_path(n.id, prefix)
|
||||
for n in self.nodes.values()
|
||||
if not isinstance(n, GraphInvocation)
|
||||
and not isinstance(n, IterateInvocation)
|
||||
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
||||
]
|
||||
)
|
||||
|
||||
# Expand graph nodes
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||
g.add_edges_from(
|
||||
[
|
||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||
for e in unique_edges
|
||||
]
|
||||
)
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
|
||||
|
||||
|
||||
@@ -800,23 +714,19 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# Nodes that have been executed
|
||||
executed: set[str] = Field(
|
||||
description="The set of node ids that have been executed", default_factory=set
|
||||
)
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[
|
||||
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The results of node executions", default_factory=dict
|
||||
)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(
|
||||
description="Errors raised when executing nodes", default_factory=dict
|
||||
)
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
|
||||
# Map of prepared/executed nodes to their original nodes
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
@@ -832,16 +742,16 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'id',
|
||||
'graph',
|
||||
'execution_graph',
|
||||
'executed',
|
||||
'executed_history',
|
||||
'results',
|
||||
'errors',
|
||||
'prepared_source_mapping',
|
||||
'source_prepared_mapping',
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
"executed",
|
||||
"executed_history",
|
||||
"results",
|
||||
"errors",
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
}
|
||||
|
||||
@@ -899,9 +809,7 @@ class GraphExecutionState(BaseModel):
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(
|
||||
self, node_path: str, iteration_node_map: list[tuple[str, str]]
|
||||
) -> list[str]:
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
|
||||
node = self.graph.get_node(node_path)
|
||||
@@ -911,20 +819,12 @@ class GraphExecutionState(BaseModel):
|
||||
# If this is an iterator node, we must create a copy for each iteration
|
||||
if isinstance(node, IterateInvocation):
|
||||
# Get input collection edge (should error if there are no inputs)
|
||||
input_collection_edge = next(
|
||||
iter(self.graph._get_input_edges(node_path, "collection"))
|
||||
)
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1]
|
||||
for n in iteration_node_map
|
||||
if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[
|
||||
input_collection_prepared_node_id
|
||||
]
|
||||
input_collection = getattr(
|
||||
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
|
||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||
self_iteration_count = len(input_collection)
|
||||
|
||||
new_nodes = list()
|
||||
@@ -939,9 +839,7 @@ class GraphExecutionState(BaseModel):
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
new_edges = list()
|
||||
for edge in input_edges:
|
||||
for input_node_id in (
|
||||
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||
):
|
||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||
new_edge = Edge(
|
||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||
@@ -982,11 +880,7 @@ class GraphExecutionState(BaseModel):
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
return g
|
||||
@@ -994,11 +888,7 @@ class GraphExecutionState(BaseModel):
|
||||
def _get_node_iterators(self, node_id: str) -> list[str]:
|
||||
"""Gets iterators for a node"""
|
||||
g = self._iterator_graph()
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
||||
return iterators
|
||||
|
||||
def _prepare(self) -> Optional[str]:
|
||||
@@ -1045,29 +935,18 @@ class GraphExecutionState(BaseModel):
|
||||
if isinstance(next_node, CollectInvocation):
|
||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||
all_iteration_mappings = list(
|
||||
itertools.chain(
|
||||
*(
|
||||
((s, p) for p in self.source_prepared_mapping[s])
|
||||
for s in next_node_parents
|
||||
)
|
||||
)
|
||||
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
|
||||
)
|
||||
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||
create_results = self._create_execution_node(
|
||||
next_node_id, all_iteration_mappings
|
||||
)
|
||||
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
||||
if create_results is not None:
|
||||
new_node_ids.extend(create_results)
|
||||
else: # Iterators or normal nodes
|
||||
# Get all iterator combinations for this node
|
||||
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
||||
iterator_nodes = self._get_node_iterators(next_node_id)
|
||||
iterator_nodes_prepared = [
|
||||
list(self.source_prepared_mapping[n]) for n in iterator_nodes
|
||||
]
|
||||
iterator_node_prepared_combinations = list(
|
||||
itertools.product(*iterator_nodes_prepared)
|
||||
)
|
||||
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
|
||||
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
|
||||
|
||||
# Select the correct prepared parents for each iteration
|
||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||
@@ -1096,31 +975,16 @@ class GraphExecutionState(BaseModel):
|
||||
return next(iter(prepared_nodes))
|
||||
|
||||
# Check if the requested node is an iterator
|
||||
prepared_iterator = next(
|
||||
(n for n in prepared_nodes if n in prepared_iterator_nodes), None
|
||||
)
|
||||
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
|
||||
if prepared_iterator is not None:
|
||||
return prepared_iterator
|
||||
|
||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||
iterator_source_node_mapping = [
|
||||
(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes
|
||||
]
|
||||
parent_iterators = [
|
||||
itn
|
||||
for itn in iterator_source_node_mapping
|
||||
if nx.has_path(graph, itn[1], source_node_path)
|
||||
]
|
||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||
|
||||
return next(
|
||||
(
|
||||
n
|
||||
for n in prepared_nodes
|
||||
if all(
|
||||
nx.has_path(execution_graph, pit[0], n)
|
||||
for pit in parent_iterators
|
||||
)
|
||||
),
|
||||
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -1130,13 +994,13 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -1221,15 +1085,18 @@ class ExposedNodeOutput(BaseModel):
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(
|
||||
description="The outputs exposed by this graph", default_factory=list
|
||||
)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
@validator("exposed_inputs", "exposed_outputs")
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
@@ -1237,23 +1104,27 @@ class LibraryGraph(BaseModel):
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
graph = values["graph"]
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
for exposed_input in values["exposed_inputs"]:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
raise ValueError(
|
||||
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
|
||||
)
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
for exposed_output in values["exposed_outputs"]:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
raise ValueError(
|
||||
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@@ -85,9 +85,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
|
||||
self.__output_folder: Path = (
|
||||
output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
)
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||
|
||||
# Validate required output folders at launch
|
||||
@@ -120,7 +118,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||
if graph is not None:
|
||||
@@ -183,9 +181,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||
if not image_name in self.__cache:
|
||||
self.__cache[image_name] = image
|
||||
self.__cache_ids.put(
|
||||
image_name
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
|
||||
@@ -67,6 +67,7 @@ IMAGE_DTO_COLS = ", ".join(
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"deleted_at",
|
||||
"starred",
|
||||
],
|
||||
)
|
||||
)
|
||||
@@ -139,6 +140,7 @@ class ImageRecordStorageBase(ABC):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
@@ -200,6 +202,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in self._cursor.fetchall()]
|
||||
|
||||
if "starred" not in columns:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -222,6 +234,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -321,6 +339,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@@ -397,7 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
@@ -426,9 +455,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
@@ -466,7 +493,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@@ -503,11 +529,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[dict],
|
||||
is_intermediate: bool = False,
|
||||
starred: bool = False,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
None if metadata is None else json.dumps(metadata)
|
||||
)
|
||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@@ -520,9 +545,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate
|
||||
is_intermediate,
|
||||
starred
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
@@ -534,6 +560,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
starred,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@@ -217,12 +217,8 @@ class ImageService(ImageServiceABC):
|
||||
session_id=session_id,
|
||||
)
|
||||
if board_id is not None:
|
||||
self._services.board_image_records.add_image_to_board(
|
||||
board_id=board_id, image_name=image_name
|
||||
)
|
||||
self._services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, graph=graph
|
||||
)
|
||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
return image_dto
|
||||
@@ -293,13 +289,12 @@ class ImageService(ImageServiceABC):
|
||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||
try:
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
metadata = self._services.image_records.get_metadata(image_name)
|
||||
|
||||
if not image_record.session_id:
|
||||
return ImageMetadata()
|
||||
return ImageMetadata(metadata=metadata)
|
||||
|
||||
session_raw = self._services.graph_execution_manager.get_raw(
|
||||
image_record.session_id
|
||||
)
|
||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||
graph = None
|
||||
|
||||
if session_raw:
|
||||
@@ -309,7 +304,6 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||
graph = None
|
||||
|
||||
metadata = self._services.image_records.get_metadata(image_name)
|
||||
return ImageMetadata(graph=graph, metadata=metadata)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
@@ -364,9 +358,7 @@ class ImageService(ImageServiceABC):
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(
|
||||
r.image_name
|
||||
),
|
||||
self._services.board_image_records.get_board_for_image(r.image_name),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@@ -398,11 +390,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
try:
|
||||
image_names = (
|
||||
self._services.board_image_records.get_all_board_image_names_for_board(
|
||||
board_id
|
||||
)
|
||||
)
|
||||
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete_many(image_names)
|
||||
|
||||
@@ -7,6 +7,7 @@ from queue import Queue
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
@@ -45,9 +46,11 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
||||
def get(self) -> InvocationQueueItem:
|
||||
item = self.__queue.get()
|
||||
|
||||
while isinstance(item, InvocationQueueItem) \
|
||||
and item.graph_execution_state_id in self.__cancellations \
|
||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
|
||||
while (
|
||||
isinstance(item, InvocationQueueItem)
|
||||
and item.graph_execution_state_id in self.__cancellations
|
||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
|
||||
):
|
||||
item = self.__queue.get()
|
||||
|
||||
# Clear old items
|
||||
|
||||
@@ -32,6 +32,7 @@ class InvocationServices:
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
|
||||
def __init__(
|
||||
@@ -47,6 +48,7 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
):
|
||||
self.board_images = board_images
|
||||
@@ -61,4 +63,5 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
||||
305
invokeai/app/services/invocation_stats.py
Normal file
305
invokeai/app/services/invocation_stats.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
|
||||
|
||||
"""
|
||||
Usage:
|
||||
|
||||
statistics = InvocationStatsService(graph_execution_manager)
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
... execute graphs...
|
||||
statistics.log_stats()
|
||||
|
||||
Typical output:
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||
|
||||
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
"""
|
||||
Initialize the InvocationStatsService and reset counters to zero
|
||||
:param graph_execution_manager: Graph execution manager for this session
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> AbstractContextManager:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
"""
|
||||
Reset all statistics for the indicated graph
|
||||
:param graph_execution_state_id
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log_stats(self):
|
||||
"""
|
||||
Write out the accumulated statistics to the log or somewhere else.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeStats:
|
||||
"""Class for tracking execution stats of an invocation node"""
|
||||
|
||||
calls: int = 0
|
||||
time_used: float = 0.0 # seconds
|
||||
max_vram: float = 0.0 # GB
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
cache_high_watermark: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeLog:
|
||||
"""Class for tracking node usage"""
|
||||
|
||||
# {node_type => NodeStats}
|
||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||
|
||||
|
||||
class InvocationStatsService(InvocationStatsServiceBase):
|
||||
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||
|
||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
# {graph_id => NodeLog}
|
||||
self._stats: Dict[str, NodeLog] = {}
|
||||
self._cache_stats: Dict[str, CacheStats] = {}
|
||||
self.ram_used: float = 0.0
|
||||
self.ram_changed: float = 0.0
|
||||
|
||||
class StatsContext:
|
||||
"""Context manager for collecting statistics."""
|
||||
|
||||
invocation: BaseInvocation = None
|
||||
collector: "InvocationStatsServiceBase" = None
|
||||
graph_id: str = None
|
||||
start_time: int = 0
|
||||
ram_used: int = 0
|
||||
model_manager: ModelManagerService = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
collector: "InvocationStatsServiceBase",
|
||||
):
|
||||
"""Initialize statistics for this run."""
|
||||
self.invocation = invocation
|
||||
self.collector = collector
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.ram_used = psutil.Process().memory_info().rss
|
||||
if self.model_manager:
|
||||
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
|
||||
def __exit__(self, *args):
|
||||
"""Called on exit from the context."""
|
||||
ram_used = psutil.Process().memory_info().rss
|
||||
self.collector.update_mem_stats(
|
||||
ram_used=ram_used / GIG,
|
||||
ram_changed=(ram_used - self.ram_used) / GIG,
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type,
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
) -> StatsContext:
|
||||
"""
|
||||
Return a context object that will capture the statistics.
|
||||
:param invocation: BaseInvocation object from the current graph.
|
||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||
"""
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
self._stats = {}
|
||||
|
||||
def reset_stats(self, graph_execution_id: str):
|
||||
"""Zero the statistics for the indicated graph."""
|
||||
try:
|
||||
self._stats.pop(graph_execution_id)
|
||||
except KeyError:
|
||||
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
||||
|
||||
def update_mem_stats(
|
||||
self,
|
||||
ram_used: float,
|
||||
ram_changed: float,
|
||||
):
|
||||
"""
|
||||
Update the collector with RAM memory usage info.
|
||||
|
||||
:param ram_used: How much RAM is currently in use.
|
||||
:param ram_changed: How much RAM changed since last generation.
|
||||
"""
|
||||
self.ram_used = ram_used
|
||||
self.ram_changed = ram_changed
|
||||
|
||||
def update_invocation_stats(
|
||||
self,
|
||||
graph_id: str,
|
||||
invocation_type: str,
|
||||
time_used: float,
|
||||
vram_used: float,
|
||||
):
|
||||
"""
|
||||
Add timing information on execution of a node. Usually
|
||||
used internally.
|
||||
:param graph_id: ID of the graph that is currently executing
|
||||
:param invocation_type: String literal type of the node
|
||||
:param time_used: Time used by node's exection (sec)
|
||||
:param vram_used: Maximum VRAM used during exection (GB)
|
||||
:param ram_used: Current RAM available (GB)
|
||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
||||
"""
|
||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||
stats = self._stats[graph_id].nodes[invocation_type]
|
||||
stats.calls += 1
|
||||
stats.time_used += time_used
|
||||
stats.max_vram = max(stats.max_vram, vram_used)
|
||||
|
||||
def log_stats(self):
|
||||
"""
|
||||
Send the statistics to the system logger at the info level.
|
||||
Stats will only be printed when the execution of the graph
|
||||
is complete.
|
||||
"""
|
||||
completed = set()
|
||||
for graph_id, node_log in self._stats.items():
|
||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||
if not current_graph_state.is_complete():
|
||||
continue
|
||||
|
||||
total_time = 0
|
||||
logger.info(f"Graph stats: {graph_id}")
|
||||
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
|
||||
for node_type, stats in self._stats[graph_id].nodes.items():
|
||||
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
|
||||
total_time += stats.time_used
|
||||
|
||||
cache_stats = self._cache_stats[graph_id]
|
||||
hwm = cache_stats.high_watermark / GIG
|
||||
tot = cache_stats.cache_size / GIG
|
||||
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
|
||||
|
||||
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
|
||||
logger.info(f"RAM used to load models: {loaded:4.2f}G")
|
||||
if torch.cuda.is_available():
|
||||
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
|
||||
logger.info("RAM cache statistics:")
|
||||
logger.info(f" Model cache hits: {cache_stats.hits}")
|
||||
logger.info(f" Model cache misses: {cache_stats.misses}")
|
||||
logger.info(f" Models cached: {cache_stats.in_cache}")
|
||||
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
|
||||
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
|
||||
|
||||
completed.add(graph_id)
|
||||
|
||||
for graph_id in completed:
|
||||
del self._stats[graph_id]
|
||||
del self._cache_stats[graph_id]
|
||||
@@ -7,6 +7,7 @@ from .graph import Graph, GraphExecutionState
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class Invoker:
|
||||
"""The invoker, used to execute invocations"""
|
||||
|
||||
@@ -16,9 +17,7 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> Optional[str]:
|
||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
|
||||
@@ -9,13 +9,15 @@ T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class PaginatedResults(GenericModel, Generic[T]):
|
||||
"""Paginated results"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
page: int = Field(description="Current Page")
|
||||
pages: int = Field(description="Total number of pages")
|
||||
per_page: int = Field(description="Number of items per page")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ItemStorageABC(ABC, Generic[T]):
|
||||
_on_changed_callbacks: list[Callable[[T], None]]
|
||||
@@ -48,9 +50,7 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Union, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
@@ -25,7 +26,7 @@ class LatentsStorageBase(ABC):
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
@@ -87,8 +88,6 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from pydantic import Field
|
||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
|
||||
from invokeai.backend.model_management import (
|
||||
@@ -21,6 +22,7 @@ from invokeai.backend.model_management import (
|
||||
ModelNotFoundException,
|
||||
)
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@@ -103,7 +105,7 @@ class ModelManagerServiceBase(ABC):
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
@@ -125,7 +127,7 @@ class ModelManagerServiceBase(ABC):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
@@ -148,12 +150,12 @@ class ModelManagerServiceBase(ABC):
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
@@ -169,21 +171,20 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(
|
||||
self
|
||||
)->List[Path]:
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
@@ -194,7 +195,7 @@ class ModelManagerServiceBase(ABC):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@@ -211,11 +212,12 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
@@ -230,19 +232,23 @@ class ModelManagerServiceBase(ABC):
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@@ -250,27 +256,34 @@ class ModelManagerServiceBase(ABC):
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
@@ -280,13 +293,15 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: ModuleType,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@@ -298,17 +313,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
|
||||
logger.debug(f'Config file={config_file}')
|
||||
|
||||
logger.debug(f"Config file={config_file}")
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else ''
|
||||
logger.info(f'GPU device = {device} {device_name}')
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
logger.info(f"GPU device = {device} {device_name}")
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == 'float32' else torch.float16
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
@@ -316,9 +331,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `max_cache_size` config setting
|
||||
max_cache_size = config.max_cache_size \
|
||||
if hasattr(config,'max_cache_size') \
|
||||
else config.max_loaded_models * 2.5
|
||||
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
@@ -332,7 +345,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info('Model manager service initialized')
|
||||
logger.info("Model manager service initialized")
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
@@ -371,7 +384,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
@@ -392,7 +405,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
@@ -405,22 +418,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type)
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
@@ -429,7 +438,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
)->None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@@ -437,7 +446,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'add/update model {model_name}')
|
||||
self.logger.debug(f"add/update model {model_name}")
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
def update_model(
|
||||
@@ -450,15 +459,15 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'update model {model_name}')
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@@ -470,7 +479,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.logger.debug(f"delete model {model_name}")
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
@@ -478,8 +487,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@@ -494,10 +505,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
@@ -524,7 +541,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
@@ -535,16 +552,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
@@ -559,18 +576,24 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
"""
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@@ -578,25 +601,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names = model_names,
|
||||
base_model = base_model,
|
||||
merged_model_name = merged_model_name,
|
||||
alpha = alpha,
|
||||
interp = interp,
|
||||
force = force,
|
||||
model_names=model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
@@ -605,28 +628,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self)->List[Path]:
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
@@ -635,10 +659,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
|
||||
self.mgr.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=new_name,
|
||||
new_base=new_base,
|
||||
)
|
||||
|
||||
8
invokeai/app/services/models/board_image.py
Normal file
8
invokeai/app/services/models/board_image.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class BoardImage(BaseModelExcludeNull):
|
||||
board_id: str = Field(description="The id of the board")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
@@ -1,40 +1,31 @@
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from pydantic import Field
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class BoardRecord(BaseModel):
|
||||
class BoardRecord(BaseModelExcludeNull):
|
||||
"""Deserialized board record."""
|
||||
|
||||
board_id: str = Field(description="The unique ID of the board.")
|
||||
"""The unique ID of the board."""
|
||||
board_name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
created_at: Union[datetime, str] = Field(
|
||||
description="The created timestamp of the board."
|
||||
)
|
||||
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(
|
||||
description="The updated timestamp of the board."
|
||||
)
|
||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the board."
|
||||
)
|
||||
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
||||
"""The updated timestamp of the image."""
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the cover image of the board."
|
||||
)
|
||||
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
||||
"""The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's cover image."
|
||||
)
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from pydantic import Extra, Field, StrictBool, StrictStr
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class ImageRecord(BaseModel):
|
||||
class ImageRecord(BaseModelExcludeNull):
|
||||
"""Deserialized image record without metadata."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
@@ -20,17 +21,11 @@ class ImageRecord(BaseModel):
|
||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||
height: int = Field(description="The height of the image in px.")
|
||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||
created_at: Union[datetime.datetime, str] = Field(
|
||||
description="The created timestamp of the image."
|
||||
)
|
||||
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.")
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime.datetime, str] = Field(
|
||||
description="The updated timestamp of the image."
|
||||
)
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the image."
|
||||
)
|
||||
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
@@ -44,33 +39,34 @@ class ImageRecord(BaseModel):
|
||||
description="The node ID that generated this image, if it is a generated image.",
|
||||
)
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
starred: bool = Field(description="Whether this image is starred.")
|
||||
"""Whether this image is starred."""
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||
"""A set of changes to apply to an image record.
|
||||
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
- `starred`: change whether the image is starred
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
description="The image's new category."
|
||||
)
|
||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
||||
"""The image's new category."""
|
||||
session_id: Optional[StrictStr] = Field(
|
||||
default=None,
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
starred: Optional[StrictBool] = Field(default=None, description="The image's new `starred` state")
|
||||
"""The image's new `starred` state."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||
"""The URLs for an image and its thumbnail."""
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
@@ -84,15 +80,17 @@ class ImageUrlsDTO(BaseModel):
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Optional[str] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
|
||||
image_record: ImageRecord,
|
||||
image_url: str,
|
||||
thumbnail_url: str,
|
||||
board_id: Optional[str],
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
@@ -110,12 +108,8 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
|
||||
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
||||
image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value))
|
||||
width = image_dict.get("width", 0)
|
||||
height = image_dict.get("height", 0)
|
||||
session_id = image_dict.get("session_id", None)
|
||||
@@ -124,6 +118,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
starred = image_dict.get("starred", False)
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
@@ -137,4 +132,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
starred=starred,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ..models.exceptions import CanceledException
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_stats import InvocationStatsServiceBase
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
@@ -24,9 +27,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
target=self.__process,
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
)
|
||||
self.__invoker_thread.daemon = (
|
||||
True # TODO: make async and do not use threads
|
||||
)
|
||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||
self.__invoker_thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
@@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
@@ -47,10 +50,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
continue
|
||||
|
||||
try:
|
||||
graph_execution_state = (
|
||||
self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
@@ -60,11 +61,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
invocation = graph_execution_state.execution_graph.get_node(
|
||||
queue_item.invocation_id
|
||||
)
|
||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
@@ -82,44 +81,48 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
outputs = invocation.invoke(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||
# this accomodates nodes which require a value, but get it only from a
|
||||
# connection
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
):
|
||||
continue
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(
|
||||
graph_execution_state
|
||||
)
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
statistics.log_stats()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
@@ -130,9 +133,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_execution_state.set_node_error(invocation.id, error)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(
|
||||
graph_execution_state
|
||||
)
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
@@ -143,13 +144,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
|
||||
statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
):
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
@@ -164,12 +163,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc()
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
)
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import sqlite3
|
||||
import json
|
||||
from threading import Lock
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
from typing import Generic, Optional, TypeVar, Union, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
|
||||
@@ -49,7 +50,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
@@ -66,9 +68,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
@@ -81,9 +81,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def get_raw(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
@@ -96,15 +94,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@@ -113,7 +109,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
items = [json.loads(r[0]) for r in result]
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
@@ -122,13 +118,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[dict]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@@ -137,7 +129,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
items = [json.loads(r[0]) for r in result]
|
||||
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
@@ -149,6 +141,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
return PaginatedResults[dict](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
@@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/{image_basename}/full"
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
@@ -17,16 +17,8 @@ from controlnet_aux.util import HWC3, resize_image
|
||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||
|
||||
lvmin_kernels_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[0, 1, 0],
|
||||
[1, 1, 1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[0, -1, -1],
|
||||
[1, 1, -1],
|
||||
[0, 1, 0]
|
||||
], dtype=np.int32)
|
||||
np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32),
|
||||
np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32),
|
||||
]
|
||||
|
||||
lvmin_kernels = []
|
||||
@@ -36,16 +28,8 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||
|
||||
lvmin_prunings_raw = [
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[0, 0, -1]
|
||||
], dtype=np.int32),
|
||||
np.array([
|
||||
[-1, -1, -1],
|
||||
[-1, 1, -1],
|
||||
[-1, 0, 0]
|
||||
], dtype=np.int32)
|
||||
np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32),
|
||||
np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32),
|
||||
]
|
||||
|
||||
lvmin_prunings = []
|
||||
@@ -99,10 +83,10 @@ def nake_nms(x):
|
||||
################################################################################
|
||||
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
||||
def pixel_perfect_resolution(
|
||||
image: np.ndarray,
|
||||
target_H: int,
|
||||
target_W: int,
|
||||
resize_mode: str,
|
||||
image: np.ndarray,
|
||||
target_H: int,
|
||||
target_W: int,
|
||||
resize_mode: str,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
||||
@@ -135,7 +119,7 @@ def pixel_perfect_resolution(
|
||||
|
||||
if resize_mode == "fill_resize":
|
||||
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
||||
|
||||
# print(f"Pixel Perfect Computation:")
|
||||
@@ -154,13 +138,7 @@ def pixel_perfect_resolution(
|
||||
# modified for InvokeAI
|
||||
###########################################################################
|
||||
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
||||
def np_img_resize(
|
||||
np_img: np.ndarray,
|
||||
resize_mode: str,
|
||||
h: int,
|
||||
w: int,
|
||||
device: torch.device = torch.device('cpu')
|
||||
):
|
||||
def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
|
||||
# if 'inpaint' in module:
|
||||
# np_img = np_img.astype(np.float32)
|
||||
# else:
|
||||
@@ -184,15 +162,14 @@ def np_img_resize(
|
||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||
y = torch.from_numpy(y)
|
||||
y = y.float() / 255.0
|
||||
y = rearrange(y, 'h w c -> 1 c h w')
|
||||
y = rearrange(y, "h w c -> 1 c h w")
|
||||
y = y.clone()
|
||||
# y = y.to(devices.get_device_for("controlnet"))
|
||||
y = y.to(device)
|
||||
y = y.clone()
|
||||
return y
|
||||
|
||||
def high_quality_resize(x: np.ndarray,
|
||||
size):
|
||||
def high_quality_resize(x: np.ndarray, size):
|
||||
# Written by lvmin
|
||||
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||
inpaint_mask = None
|
||||
@@ -244,7 +221,7 @@ def np_img_resize(
|
||||
return y
|
||||
|
||||
# if resize_mode == external_code.ResizeMode.RESIZE:
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
if resize_mode == "just_resize": # RESIZE
|
||||
np_img = high_quality_resize(np_img, (w, h))
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
@@ -270,20 +247,21 @@ def np_img_resize(
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (h - new_h) // 2)
|
||||
pad_w = max(0, (w - new_w) // 2)
|
||||
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
|
||||
high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
|
||||
np_img = high_quality_background
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||
k = max(k0, k1)
|
||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = np_img.shape
|
||||
pad_h = max(0, (new_h - h) // 2)
|
||||
pad_w = max(0, (new_w - w) // 2)
|
||||
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
|
||||
np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
|
||||
np_img = safe_numpy(np_img)
|
||||
return get_pytorch_control(np_img), np_img
|
||||
|
||||
|
||||
def prepare_control_image(
|
||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||
@@ -301,15 +279,17 @@ def prepare_control_image(
|
||||
resize_mode="just_resize_simple",
|
||||
):
|
||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||
if (resize_mode == "just_resize_simple" or
|
||||
resize_mode == "crop_resize_simple" or
|
||||
resize_mode == "fill_resize_simple"):
|
||||
if (
|
||||
resize_mode == "just_resize_simple"
|
||||
or resize_mode == "crop_resize_simple"
|
||||
or resize_mode == "fill_resize_simple"
|
||||
):
|
||||
image = image.convert("RGB")
|
||||
if (resize_mode == "just_resize_simple"):
|
||||
if resize_mode == "just_resize_simple":
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
elif (resize_mode == "crop_resize_simple"): # not yet implemented
|
||||
elif resize_mode == "crop_resize_simple": # not yet implemented
|
||||
pass
|
||||
elif (resize_mode == "fill_resize_simple"): # not yet implemented
|
||||
elif resize_mode == "fill_resize_simple": # not yet implemented
|
||||
pass
|
||||
nimage = np.array(image)
|
||||
nimage = nimage[None, :]
|
||||
@@ -320,7 +300,7 @@ def prepare_control_image(
|
||||
timage = torch.from_numpy(nimage)
|
||||
|
||||
# use fancy lvmin controlnet resizing
|
||||
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
|
||||
elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize":
|
||||
nimage = np.array(image)
|
||||
timage, nimage = np_img_resize(
|
||||
np_img=nimage,
|
||||
@@ -336,7 +316,7 @@ def prepare_control_image(
|
||||
exit(1)
|
||||
|
||||
timage = timage.to(device=device, dtype=dtype)
|
||||
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
||||
if do_classifier_free_guidance and not cfg_injection:
|
||||
timage = torch.cat([timage] * 2)
|
||||
return timage
|
||||
|
||||
@@ -18,5 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
rng = np.random.default_rng(seed=0)
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
||||
23
invokeai/app/util/model_exclude_null.py
Normal file
23
invokeai/app/util/model_exclude_null.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
"""
|
||||
We want to exclude null values from objects that make their way to the client.
|
||||
|
||||
Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
|
||||
dict method to do this.
|
||||
|
||||
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
|
||||
"""
|
||||
|
||||
|
||||
class BaseModelExcludeNull(BaseModel):
|
||||
def dict(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Override the default dict method to exclude None values in the response
|
||||
"""
|
||||
kwargs.pop("exclude_none", None)
|
||||
return super().dict(*args, exclude_none=True, **kwargs)
|
||||
|
||||
pass
|
||||
@@ -4,24 +4,21 @@ from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
|
||||
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
|
||||
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
|
||||
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
@@ -32,6 +29,7 @@ def stable_diffusion_step_callback(
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
@@ -59,23 +57,50 @@ def stable_diffusion_step_callback(
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
@@ -89,58 +114,6 @@ def stable_diffusion_step_callback(
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
total_steps=node["steps"],
|
||||
order=intermediate_state.order,
|
||||
total_steps=intermediate_state.total_steps,
|
||||
)
|
||||
|
||||
def stable_diffusion_xl_step_callback(
|
||||
context: InvocationContext,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
sample,
|
||||
step,
|
||||
total_steps,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[ 0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[ 0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
#[ 0.0478, 0.1285, 0.0478],
|
||||
#[ 0.1285, 0.2948, 0.1285],
|
||||
#[ 0.0478, 0.1285, 0.0478],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
@@ -1,15 +1,5 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import (
|
||||
ModelManager, ModelCache, BaseModelType,
|
||||
ModelType, SubModelType, ModelInfo
|
||||
)
|
||||
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from .model_management.models import SilenceWarnings
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
Initialization file for the invokeai.generator package
|
||||
"""
|
||||
from .base import (
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorOutput,
|
||||
Img2Img,
|
||||
Inpaint,
|
||||
Generator,
|
||||
)
|
||||
from .inpaint import infill_methods
|
||||
@@ -1,580 +0,0 @@
|
||||
"""
|
||||
Base class for invokeai.backend.generator.*
|
||||
including img2img, txt2img, and inpaint
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import dataclasses
|
||||
import diffusers
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from abc import ABCMeta
|
||||
from argparse import Namespace
|
||||
from contextlib import nullcontext
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageChops, ImageFilter
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm import trange
|
||||
from typing import Callable, List, Iterator, Optional, Type, Union
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
seed: Optional[int]=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
cfg_scale: float=7.5
|
||||
steps: int=20
|
||||
ddim_eta: float=0.0
|
||||
scheduler: str='ddim'
|
||||
precision: str='float16'
|
||||
perlin: float=0.0
|
||||
threshold: float=0.0
|
||||
seamless: bool=False
|
||||
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||
h_symmetry_time_pct: Optional[float]=None
|
||||
v_symmetry_time_pct: Optional[float]=None
|
||||
variation_amount: float = 0.0
|
||||
with_variations: list=field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorOutput:
|
||||
'''
|
||||
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||
operation, including the image, its seed, the model name used to generate the image
|
||||
and the model hash, as well as all the generate() parameters that went into
|
||||
generating the image (in .params, also available as attributes)
|
||||
'''
|
||||
image: Image.Image
|
||||
seed: int
|
||||
model_hash: str
|
||||
attention_maps_images: List[Image.Image]
|
||||
params: Namespace
|
||||
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
**kwargs,
|
||||
):
|
||||
self.model_info=model_info
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
conditioning: tuple,
|
||||
scheduler,
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
**keyword_args,
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
'''
|
||||
Return an iterator across the indicated number of generations.
|
||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||
object. Use like this:
|
||||
|
||||
outputs = txt2img.generate(prompt='banana sushi', iterations=5)
|
||||
for result in outputs:
|
||||
print(result.image, result.seed)
|
||||
|
||||
In the typical case of wanting to get just a single image, iterations
|
||||
defaults to 1 and do:
|
||||
|
||||
output = next(txt2img.generate(prompt='banana sushi')
|
||||
|
||||
Pass None to get an infinite iterator.
|
||||
|
||||
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
|
||||
for o in outputs:
|
||||
print(o.image, o.seed)
|
||||
|
||||
'''
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
model_info = self.model_info
|
||||
model_name = model_info.name
|
||||
model_hash = model_info.hash
|
||||
with model_info.context as model:
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
generator_args.get('seamless',False),
|
||||
generator_args.get('seamless_axes')
|
||||
)
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(
|
||||
conditioning=conditioning,
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
yield output
|
||||
|
||||
@classmethod
|
||||
def schedulers(self)->List[str]:
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
In derived classes return the name of the generator to apply.
|
||||
If you don't override will return the name of the derived
|
||||
class, which nicely parallels the generator class names.
|
||||
'''
|
||||
return Generator
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
init_image: Union[Image.Image, torch.FloatTensor],
|
||||
strength: float=0.75,
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(init_image=init_image,
|
||||
strength=strength,
|
||||
**keyword_args
|
||||
)
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .img2img import Img2Img
|
||||
return Img2Img
|
||||
|
||||
# ------------------------------------
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(
|
||||
mask_image=mask_image,
|
||||
seam_size=seam_size,
|
||||
seam_blur=seam_blur,
|
||||
seam_strength=seam_strength,
|
||||
seam_steps=seam_steps,
|
||||
tile_size=tile_size,
|
||||
inpaint_replace=inpaint_replace,
|
||||
infill_method=infill_method,
|
||||
inpaint_width=inpaint_width,
|
||||
inpaint_height=inpaint_height,
|
||||
inpaint_fill=inpaint_fill,
|
||||
**keyword_args
|
||||
)
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .inpaint import Inpaint
|
||||
return Inpaint
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
precision: str
|
||||
model: DiffusionPipeline
|
||||
|
||||
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.unet.config.in_channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"image_iterator() must be implemented in a descendent class"
|
||||
)
|
||||
|
||||
def set_variation(self, seed, variation_amount, with_variations):
|
||||
self.seed = seed
|
||||
self.variation_amount = variation_amount
|
||||
self.with_variations = with_variations
|
||||
|
||||
def generate(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
sampler,
|
||||
init_image=None,
|
||||
iterations=1,
|
||||
seed=None,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
free_gpu_mem: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
scope = nullcontext
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
make_image = self.get_make_image(
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
width=width,
|
||||
height=height,
|
||||
step_callback=step_callback,
|
||||
threshold=threshold,
|
||||
perlin=perlin,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
attention_maps_callback=attention_maps_callback,
|
||||
**kwargs,
|
||||
)
|
||||
results = []
|
||||
seed = seed if seed is not None and seed >= 0 else self.new_seed()
|
||||
first_seed = seed
|
||||
seed, initial_noise = self.generate_initial_noise(seed, width, height)
|
||||
|
||||
# There used to be an additional self.model.ema_scope() here, but it breaks
|
||||
# the inpaint-1.5 model. Not sure what it did.... ?
|
||||
with scope(self.model.device.type):
|
||||
for n in trange(iterations, desc="Generating"):
|
||||
x_T = None
|
||||
if self.variation_amount > 0:
|
||||
set_seed(seed)
|
||||
target_noise = self.get_noise(width, height)
|
||||
x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
|
||||
elif initial_noise is not None:
|
||||
# i.e. we specified particular variations
|
||||
x_T = initial_noise
|
||||
else:
|
||||
set_seed(seed)
|
||||
try:
|
||||
x_T = self.get_noise(width, height)
|
||||
except:
|
||||
logger.error("An error occurred while getting initial noise")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
image = make_image(x_T, seed)
|
||||
|
||||
results.append([image, seed, attention_maps_images])
|
||||
|
||||
if image_callback is not None:
|
||||
attention_maps_image = (
|
||||
None
|
||||
if len(attention_maps_images) == 0
|
||||
else attention_maps_images[-1]
|
||||
)
|
||||
image_callback(
|
||||
image,
|
||||
seed,
|
||||
first_seed=first_seed,
|
||||
attention_maps_image=attention_maps_image,
|
||||
)
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = (
|
||||
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||
)
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self, samples) -> Image.Image:
|
||||
"""
|
||||
Given samples returned from a sampler, converts
|
||||
it into a PIL Image
|
||||
"""
|
||||
with torch.inference_mode():
|
||||
image = self.model.decode_latents(samples)
|
||||
return self.model.numpy_to_pil(image)[0]
|
||||
|
||||
def repaste_and_color_correct(
|
||||
self,
|
||||
result: Image.Image,
|
||||
init_image: Image.Image,
|
||||
init_mask: Image.Image,
|
||||
mask_blur_radius: int = 8,
|
||||
) -> Image.Image:
|
||||
if init_image is None or init_mask is None:
|
||||
return result
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = (
|
||||
init_mask.getchannel("A")
|
||||
if init_mask.mode == "RGBA"
|
||||
else init_mask.convert("L")
|
||||
)
|
||||
pil_init_image = init_image.convert(
|
||||
"RGBA"
|
||||
) # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel("A"), dtype=np.uint8)
|
||||
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
|
||||
# Get numpy version of result
|
||||
np_image = np.asarray(result, dtype=np.uint8)
|
||||
|
||||
# Mask and calculate mean and standard deviation
|
||||
mask_pixels = init_a_pixels * init_mask_pixels > 0
|
||||
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
|
||||
np_image_masked = np_image[mask_pixels, :]
|
||||
|
||||
if np_init_rgb_pixels_masked.size > 0:
|
||||
init_means = np_init_rgb_pixels_masked.mean(axis=0)
|
||||
init_std = np_init_rgb_pixels_masked.std(axis=0)
|
||||
gen_means = np_image_masked.mean(axis=0)
|
||||
gen_std = np_image_masked.std(axis=0)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = np_image.copy()
|
||||
np_matched_result[:, :, :] = (
|
||||
(
|
||||
(
|
||||
(
|
||||
np_matched_result[:, :, :].astype(np.float32)
|
||||
- gen_means[None, None, :]
|
||||
)
|
||||
/ gen_std[None, None, :]
|
||||
)
|
||||
* init_std[None, None, :]
|
||||
+ init_means[None, None, :]
|
||||
)
|
||||
.clip(0, 255)
|
||||
.astype(np.uint8)
|
||||
)
|
||||
matched_result = Image.fromarray(np_matched_result, mode="RGB")
|
||||
else:
|
||||
matched_result = Image.fromarray(np_image, mode="RGB")
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv2.erode(
|
||||
nm,
|
||||
kernel=np.ones((3, 3), dtype=np.uint8),
|
||||
iterations=int(mask_blur_radius / 2),
|
||||
)
|
||||
pmd = Image.fromarray(nmd, mode="L")
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
multiplied_blurred_init_mask = ImageChops.multiply(
|
||||
blurred_init_mask, self.pil_image.split()[-1]
|
||||
)
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
@staticmethod
|
||||
def sample_to_lowres_estimated_image(samples):
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=samples.dtype,
|
||||
device=samples.device,
|
||||
)
|
||||
|
||||
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
def generate_initial_noise(self, seed, width, height):
|
||||
initial_noise = None
|
||||
if self.variation_amount > 0 or len(self.with_variations) > 0:
|
||||
# use fixed initial noise plus random noise per iteration
|
||||
set_seed(seed)
|
||||
initial_noise = self.get_noise(width, height)
|
||||
for v_seed, v_weight in self.with_variations:
|
||||
seed = v_seed
|
||||
set_seed(seed)
|
||||
next_noise = self.get_noise(width, height)
|
||||
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
|
||||
if self.variation_amount > 0:
|
||||
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
|
||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return (seed, initial_noise)
|
||||
|
||||
def get_perlin_noise(self, width, height):
|
||||
fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
# round up to the nearest block of 8
|
||||
temp_width = int((width + 7) / 8) * 8
|
||||
temp_height = int((height + 7) / 8) * 8
|
||||
noise = torch.stack(
|
||||
[
|
||||
rand_perlin_2d(
|
||||
(temp_height, temp_width), (8, 8), device=self.model.device
|
||||
).to(fixdevice)
|
||||
for _ in range(input_channels)
|
||||
],
|
||||
dim=0,
|
||||
).to(self.model.device)
|
||||
return noise[0:4, 0:height, 0:width]
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(self.model.device)
|
||||
|
||||
return v2
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or "."
|
||||
if not os.path.exists(dirname):
|
||||
logger.info(f"creating directory {dirname}")
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath, "PNG")
|
||||
|
||||
def torch_dtype(self) -> torch.dtype:
|
||||
return torch.float16 if self.precision == "float16" else torch.float32
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self, width, height):
|
||||
device = self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device,
|
||||
)
|
||||
if self.perlin > 0.0:
|
||||
perlin_noise = self.get_perlin_noise(
|
||||
width // self.downsampling_factor, height // self.downsampling_factor
|
||||
)
|
||||
x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user