mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
555 Commits
psyche/per
...
5.10.0dev1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bea9d037bc | ||
|
|
eed4260975 | ||
|
|
6a79b1c64c | ||
|
|
4cdfe3e30d | ||
|
|
df294db236 | ||
|
|
52e247cfe0 | ||
|
|
0a13640bf3 | ||
|
|
643b71f56c | ||
|
|
b745411866 | ||
|
|
c3ffb0feed | ||
|
|
f53ff5fa3c | ||
|
|
b2337b56bd | ||
|
|
fb777b4502 | ||
|
|
4c12f5a011 | ||
|
|
d4655ea21a | ||
|
|
752b62d0b5 | ||
|
|
9a0efb308d | ||
|
|
b4c276b50f | ||
|
|
db03c196a1 | ||
|
|
6bc36b697d | ||
|
|
b7d71d3028 | ||
|
|
fa1ebd9d2f | ||
|
|
eed5d02069 | ||
|
|
3650d91045 | ||
|
|
6c7d08cacb | ||
|
|
bb1c40f222 | ||
|
|
bfb117d0e0 | ||
|
|
b31c1022c3 | ||
|
|
a5851ca31c | ||
|
|
77bf5c15bb | ||
|
|
d26b7a1a12 | ||
|
|
595133463e | ||
|
|
6155f9ff9e | ||
|
|
7be87c8048 | ||
|
|
9868c3bfe3 | ||
|
|
8b299d0bac | ||
|
|
a44bfb4658 | ||
|
|
96fb5f6881 | ||
|
|
4109ea5324 | ||
|
|
f6c2ee5040 | ||
|
|
965753bf8b | ||
|
|
40c53ab95c | ||
|
|
aaa6211625 | ||
|
|
f6d770eac9 | ||
|
|
47cb61cd62 | ||
|
|
b0fdc8ae1c | ||
|
|
ed9b30efda | ||
|
|
168e5eeff0 | ||
|
|
7acaa86bdf | ||
|
|
96c0393fe7 | ||
|
|
403f795c5e | ||
|
|
c0f88a083e | ||
|
|
542b182899 | ||
|
|
3f58c68c09 | ||
|
|
e50c7e5947 | ||
|
|
4a83700fe4 | ||
|
|
c9992914d6 | ||
|
|
c25f6d1f84 | ||
|
|
a53e1ccf08 | ||
|
|
1af9930951 | ||
|
|
c276c1cbee | ||
|
|
c619348f29 | ||
|
|
c6f96613fc | ||
|
|
258bf736da | ||
|
|
0d75c99476 | ||
|
|
323d409fb6 | ||
|
|
f251722f56 | ||
|
|
7004fde41b | ||
|
|
c9dc27afbb | ||
|
|
efd14ec0e4 | ||
|
|
21ee2b6251 | ||
|
|
82dd2d508f | ||
|
|
ffb5f6c6a6 | ||
|
|
5c5fff9ecb | ||
|
|
9ca071819b | ||
|
|
b14d8e8192 | ||
|
|
3f12a43e75 | ||
|
|
5a59f6e3b8 | ||
|
|
60b5aef16a | ||
|
|
35222a8835 | ||
|
|
0e8b5484d5 | ||
|
|
454506c83e | ||
|
|
8f6ab67376 | ||
|
|
5afcc7778f | ||
|
|
325e07d330 | ||
|
|
a016bdc159 | ||
|
|
a14f0b2864 | ||
|
|
721483318a | ||
|
|
be04743649 | ||
|
|
92f0c28d6c | ||
|
|
a6b94e8ca4 | ||
|
|
00b11ef795 | ||
|
|
182580ff69 | ||
|
|
8e9d5c1187 | ||
|
|
99aac5870e | ||
|
|
c1b475c585 | ||
|
|
ec44e68cbf | ||
|
|
73dbebbcc3 | ||
|
|
09f971467d | ||
|
|
2c71b0e873 | ||
|
|
92f69ac463 | ||
|
|
3b154df71a | ||
|
|
64aa965160 | ||
|
|
d715c27d07 | ||
|
|
515084577c | ||
|
|
7596c07a64 | ||
|
|
98fd1d949b | ||
|
|
6312e6aa8f | ||
|
|
6435f11bae | ||
|
|
1c69b9b1fa | ||
|
|
731970ff88 | ||
|
|
038bac1614 | ||
|
|
ed9efe7740 | ||
|
|
ffa0beba7a | ||
|
|
75d793f1c4 | ||
|
|
2b086917e0 | ||
|
|
a9f2738086 | ||
|
|
3a56799ea5 | ||
|
|
3162ce94dc | ||
|
|
c0dc6ac4e1 | ||
|
|
fed1995525 | ||
|
|
5006e23456 | ||
|
|
2f063bddda | ||
|
|
23a26422fd | ||
|
|
434f195a96 | ||
|
|
6a4c2d692c | ||
|
|
5127a07cf9 | ||
|
|
0b4c6f0ab4 | ||
|
|
d8450033ea | ||
|
|
3938736bd8 | ||
|
|
fb2c7b9566 | ||
|
|
29449ec27d | ||
|
|
e38f778d28 | ||
|
|
f5e78436a8 | ||
|
|
6a15b5d9be | ||
|
|
a629102c87 | ||
|
|
848ade8ab8 | ||
|
|
2110feb01c | ||
|
|
f3e1821957 | ||
|
|
bbcf93089a | ||
|
|
66f41aa307 | ||
|
|
8a709766b3 | ||
|
|
efaa20a7a1 | ||
|
|
3e4c808b23 | ||
|
|
00e3931af4 | ||
|
|
08bea07f8b | ||
|
|
166d2f0e39 | ||
|
|
21f346717a | ||
|
|
f966fb8b9c | ||
|
|
c2b20a5387 | ||
|
|
bed9089fe6 | ||
|
|
d34a4f765c | ||
|
|
efe4708b8b | ||
|
|
7cb1f61a9e | ||
|
|
6e2ef34cba | ||
|
|
d208b99a47 | ||
|
|
47eeafa5cb | ||
|
|
0cb00fbe53 | ||
|
|
a7e8ed3bc2 | ||
|
|
22eb25be48 | ||
|
|
a077f3fefc | ||
|
|
c013a6e38d | ||
|
|
6cfeb71bed | ||
|
|
534f993023 | ||
|
|
67f9b6420c | ||
|
|
61bf065237 | ||
|
|
e78cf889ee | ||
|
|
5d13f0ba15 | ||
|
|
633b9afa46 | ||
|
|
f1889b259d | ||
|
|
ed21d0b57e | ||
|
|
df90da28e1 | ||
|
|
702054aa62 | ||
|
|
636ec1de6e | ||
|
|
063d07fd41 | ||
|
|
c78eac624e | ||
|
|
05de3b7a84 | ||
|
|
9cc2232b6f | ||
|
|
9fdc06b447 | ||
|
|
5ea3ec5cc8 | ||
|
|
f13a07ba6a | ||
|
|
a913f0163d | ||
|
|
f7cfbd1323 | ||
|
|
2806b60701 | ||
|
|
d8c3af624b | ||
|
|
feed44b68d | ||
|
|
247f3b5d67 | ||
|
|
8e14f9d971 | ||
|
|
bdb44ee48d | ||
|
|
b57f5330c5 | ||
|
|
ade3c015b4 | ||
|
|
7fe4d4c21a | ||
|
|
133a7fde55 | ||
|
|
6375214878 | ||
|
|
b9972be7f1 | ||
|
|
e61c5a3f26 | ||
|
|
8c633786f6 | ||
|
|
8703eea49b | ||
|
|
c8888be4c3 | ||
|
|
11963a65a4 | ||
|
|
ab6422fdf7 | ||
|
|
1f8632029e | ||
|
|
88a762474d | ||
|
|
e6dd721e33 | ||
|
|
2a09604baf | ||
|
|
f94f00ede0 | ||
|
|
37af281299 | ||
|
|
fc82775d7a | ||
|
|
9ed46f60b7 | ||
|
|
9a389e6b93 | ||
|
|
2ef1ecf381 | ||
|
|
41de112932 | ||
|
|
e9714fe476 | ||
|
|
3f29293e39 | ||
|
|
db1aa38e98 | ||
|
|
12717d4a4d | ||
|
|
1953f3cbcd | ||
|
|
3469fc9843 | ||
|
|
7cdd4187a9 | ||
|
|
ad66c101d2 | ||
|
|
28d3356710 | ||
|
|
81e70fb9d2 | ||
|
|
971c425734 | ||
|
|
b09008c530 | ||
|
|
f9f99f873d | ||
|
|
7f93f1b600 | ||
|
|
b1d336ce8a | ||
|
|
40c7be8f5d | ||
|
|
24218b34bf | ||
|
|
d970c6d6d5 | ||
|
|
e5308be0bb | ||
|
|
7d5687e9ff | ||
|
|
7adac4581a | ||
|
|
962db86cac | ||
|
|
d65ec0e250 | ||
|
|
7fdde5e84a | ||
|
|
895956bcfe | ||
|
|
f27d26cfa2 | ||
|
|
965bcba6c2 | ||
|
|
c9f2460ff2 | ||
|
|
5abbbf4b5b | ||
|
|
e66688edbf | ||
|
|
a519483f95 | ||
|
|
75c91604bb | ||
|
|
53bdaba7b6 | ||
|
|
f3f405ca77 | ||
|
|
dda69950a7 | ||
|
|
b2198b9fa7 | ||
|
|
02b91e8e7b | ||
|
|
09bf7c35eb | ||
|
|
deb9a65b3d | ||
|
|
5be9a7227c | ||
|
|
bb9f886bd4 | ||
|
|
46520946f8 | ||
|
|
830880a6fc | ||
|
|
63b94a8ff3 | ||
|
|
f12924a1e1 | ||
|
|
f8e51c86f5 | ||
|
|
654e992630 | ||
|
|
21f247f499 | ||
|
|
8bcd9fe4b7 | ||
|
|
c84a646735 | ||
|
|
b52f8121af | ||
|
|
05bed3fddd | ||
|
|
87ea20192f | ||
|
|
2f9c95c462 | ||
|
|
47cadbb48e | ||
|
|
23518b9830 | ||
|
|
94dcf391a6 | ||
|
|
637b93d2d8 | ||
|
|
565b160060 | ||
|
|
e7a60c01ed | ||
|
|
4b54ccc29c | ||
|
|
c4183ec98c | ||
|
|
5a9cbe35e0 | ||
|
|
df18fe0298 | ||
|
|
e5591d145f | ||
|
|
371c187fc3 | ||
|
|
bdd0b90769 | ||
|
|
4377158503 | ||
|
|
c8c27079ed | ||
|
|
d8b9a8d0dd | ||
|
|
39a4608d15 | ||
|
|
cd2d5431db | ||
|
|
c04cdd9779 | ||
|
|
b86ac5e049 | ||
|
|
e982c95687 | ||
|
|
665236bb79 | ||
|
|
0eeb0dd67b | ||
|
|
28c74cbe38 | ||
|
|
7414f68acc | ||
|
|
a984462b80 | ||
|
|
c6c2567203 | ||
|
|
f05c8b909f | ||
|
|
73330a1308 | ||
|
|
6f568d48ed | ||
|
|
81a97f3796 | ||
|
|
3f9535d2f9 | ||
|
|
83bfbdcad4 | ||
|
|
729428084c | ||
|
|
523a932ecc | ||
|
|
21be7d7157 | ||
|
|
a29fb18c0b | ||
|
|
aed446f013 | ||
|
|
e81c9b0d6e | ||
|
|
f45400a275 | ||
|
|
89f457c486 | ||
|
|
30ed09a36e | ||
|
|
3334652acc | ||
|
|
e83536f396 | ||
|
|
97593f95f6 | ||
|
|
7f14cee17e | ||
|
|
0a836d6fc1 | ||
|
|
54e781d5bb | ||
|
|
aa71d0c817 | ||
|
|
07313e429d | ||
|
|
bad5023238 | ||
|
|
73a0d2c06c | ||
|
|
918e9c8ccc | ||
|
|
1e388e9ca4 | ||
|
|
5b84d45932 | ||
|
|
dc3f1184b2 | ||
|
|
87438bcad7 | ||
|
|
afd894fd04 | ||
|
|
df305c0b99 | ||
|
|
deecb7f3c3 | ||
|
|
dd5f353465 | ||
|
|
a8759ea0a6 | ||
|
|
3ff529c718 | ||
|
|
3b0fecafb0 | ||
|
|
099011000f | ||
|
|
155daa3137 | ||
|
|
c493e223cf | ||
|
|
124ca23f8b | ||
|
|
a8023cbcb6 | ||
|
|
b733d3897e | ||
|
|
ef95b37ace | ||
|
|
4feff5a185 | ||
|
|
6c8dc32d5c | ||
|
|
e5da808b2f | ||
|
|
7d3434da62 | ||
|
|
4cc70d9f16 | ||
|
|
7988bc1a59 | ||
|
|
1756d885f6 | ||
|
|
9ec4d968aa | ||
|
|
76c09301f9 | ||
|
|
1cf8749754 | ||
|
|
5d6c468833 | ||
|
|
80b3f44ae8 | ||
|
|
c77c12aa1d | ||
|
|
731992c5ec | ||
|
|
c259899bf4 | ||
|
|
f62b9ad919 | ||
|
|
57533657f9 | ||
|
|
e35537e60a | ||
|
|
be53b89203 | ||
|
|
a215eeaabf | ||
|
|
d86b392bfd | ||
|
|
3e9e45b177 | ||
|
|
907d960745 | ||
|
|
bfdace6437 | ||
|
|
a89d68b93a | ||
|
|
59a8c0d441 | ||
|
|
d5d08f6569 | ||
|
|
8a4282365e | ||
|
|
b9c7bc8b0e | ||
|
|
0f45ee04a2 | ||
|
|
839a791509 | ||
|
|
f03a2bf03f | ||
|
|
4136817d30 | ||
|
|
7f0452173b | ||
|
|
8e46b03f09 | ||
|
|
9045237bfb | ||
|
|
58959a18cb | ||
|
|
e51588197f | ||
|
|
c5319ac48c | ||
|
|
50657650c2 | ||
|
|
f657c95e45 | ||
|
|
2d3a2f9842 | ||
|
|
008837642e | ||
|
|
1a84a2fb7e | ||
|
|
b87febcf4c | ||
|
|
95a9bb6c7b | ||
|
|
93ec9a048f | ||
|
|
ec6cea6705 | ||
|
|
bfbcaad8c2 | ||
|
|
3694158434 | ||
|
|
814fb939c0 | ||
|
|
4cb73e6c19 | ||
|
|
e8aed67cf1 | ||
|
|
f56dd01419 | ||
|
|
ed9cd6a7a2 | ||
|
|
c44c28ec4c | ||
|
|
e1f7359171 | ||
|
|
3e97d49a69 | ||
|
|
c12585e52d | ||
|
|
b39774a57c | ||
|
|
8988539cd5 | ||
|
|
88c68e8016 | ||
|
|
5073c7d0a3 | ||
|
|
84e86819b8 | ||
|
|
440e3e01ac | ||
|
|
c2302f7ab1 | ||
|
|
2594eed1af | ||
|
|
e8db1c1d5a | ||
|
|
d5c5e8e8ed | ||
|
|
518a7c941f | ||
|
|
bdafe53f2e | ||
|
|
cf0cbaf0ae | ||
|
|
ac6fc6eccb | ||
|
|
07d65b8fd1 | ||
|
|
3c2e6378ca | ||
|
|
445f122f37 | ||
|
|
8c0ee9c48f | ||
|
|
0eb237ac64 | ||
|
|
9aa04f0bea | ||
|
|
76e2f41ec7 | ||
|
|
1353c3301a | ||
|
|
bf209663ac | ||
|
|
04b96dd7b4 | ||
|
|
79b2c68853 | ||
|
|
aac456527e | ||
|
|
c88b835373 | ||
|
|
9da116fd3d | ||
|
|
201d7f1fdb | ||
|
|
17a5b1bd28 | ||
|
|
a409aec00f | ||
|
|
b0593eda92 | ||
|
|
9acb24914f | ||
|
|
ab4433da2f | ||
|
|
d4423aa16f | ||
|
|
1f6430c1b0 | ||
|
|
8e28888bc4 | ||
|
|
b6b21dbcbf | ||
|
|
7b48ef2264 | ||
|
|
9c542ed655 | ||
|
|
4c02ba908a | ||
|
|
82293ae3b2 | ||
|
|
f1fde792ee | ||
|
|
e82393f7ed | ||
|
|
d5211a8088 | ||
|
|
3b095b5945 | ||
|
|
34959ef573 | ||
|
|
7f10f8f96a | ||
|
|
f2689598c0 | ||
|
|
551c78d9f3 | ||
|
|
0cfd713b93 | ||
|
|
45f5d7617a | ||
|
|
f49df7d327 | ||
|
|
87ed0ed48a | ||
|
|
d445c88e4c | ||
|
|
c15c43ed2a | ||
|
|
d2f8db9745 | ||
|
|
c1cf01a038 | ||
|
|
2bfb4fc79c | ||
|
|
d037d8f9aa | ||
|
|
d5401e8443 | ||
|
|
d193e4f02a | ||
|
|
ec493e30ee | ||
|
|
081b931edf | ||
|
|
8cd7035494 | ||
|
|
4de6fd3ae6 | ||
|
|
3feb1a6600 | ||
|
|
ea2320c57b | ||
|
|
0ad0016c2d | ||
|
|
c2a3c66e49 | ||
|
|
c0a0d20935 | ||
|
|
028d8d8ead | ||
|
|
657095d2e2 | ||
|
|
1c47dc997e | ||
|
|
a3de6b6165 | ||
|
|
e57f0ff055 | ||
|
|
0362bd5a06 | ||
|
|
feee4c49a2 | ||
|
|
42e052d6f2 | ||
|
|
b03e429b26 | ||
|
|
7399909029 | ||
|
|
c8aaf5e76b | ||
|
|
0cdf7a7048 | ||
|
|
41985487d3 | ||
|
|
41d5a17114 | ||
|
|
14f9d5b6bc | ||
|
|
eec4bdb038 | ||
|
|
f3dd44044a | ||
|
|
61a22eb8cb | ||
|
|
03ca83fe13 | ||
|
|
8f1e25c387 | ||
|
|
29cf4bc002 | ||
|
|
9428642806 | ||
|
|
8620572524 | ||
|
|
f44c7e824d | ||
|
|
c5b8bde285 | ||
|
|
4c86a7ecbf | ||
|
|
b9f9d1c152 | ||
|
|
7567ee2adf | ||
|
|
0e632dbc5c | ||
|
|
49191709a0 | ||
|
|
3af7fc26fa | ||
|
|
a36a627f83 | ||
|
|
b31c71f302 | ||
|
|
5302d4890f | ||
|
|
766b752572 | ||
|
|
7feae5e5ce | ||
|
|
26730ca702 | ||
|
|
1e2c7c51b5 | ||
|
|
da2b6815ac | ||
|
|
68d14de3ee | ||
|
|
38991ffc35 | ||
|
|
f345c0fabc | ||
|
|
ca23b5337e | ||
|
|
35910d3952 | ||
|
|
6f1dcf385b | ||
|
|
84c9ecc83f | ||
|
|
52aa839b7e | ||
|
|
316ed1d478 | ||
|
|
3519e8ae39 | ||
|
|
82f645c7a1 | ||
|
|
cc36cfb617 | ||
|
|
ded8a84284 | ||
|
|
94771ea626 | ||
|
|
51d661023e | ||
|
|
d215829b91 | ||
|
|
fad6c67f01 | ||
|
|
f366640d46 | ||
|
|
36a3fba8cb | ||
|
|
b2ff83092f | ||
|
|
d2db38a5b9 | ||
|
|
fa988a6273 | ||
|
|
149f60946c | ||
|
|
ee9d620a36 | ||
|
|
4e8ce4abab | ||
|
|
d40f2fa37c | ||
|
|
933f4f6857 | ||
|
|
f499b2db7b | ||
|
|
706aaf7460 | ||
|
|
4a706d00bb | ||
|
|
2a8bff601f | ||
|
|
3f0e3192f6 | ||
|
|
c65147e2ff | ||
|
|
1c14e257a3 | ||
|
|
fe24217082 | ||
|
|
aee847065c | ||
|
|
525da3257c | ||
|
|
559654f0ca | ||
|
|
5d33874d58 | ||
|
|
0063315139 | ||
|
|
1cbd609860 | ||
|
|
047c643295 | ||
|
|
d1e03aa1c5 | ||
|
|
1bb8edf57e | ||
|
|
a3e78f0db6 | ||
|
|
80d38c0e47 | ||
|
|
22362350dc | ||
|
|
275d891f48 | ||
|
|
3848e1926b |
@@ -1,9 +1,11 @@
|
||||
*
|
||||
!invokeai
|
||||
!pyproject.toml
|
||||
!uv.lock
|
||||
!docker/docker-entrypoint.sh
|
||||
!LICENSE
|
||||
|
||||
**/dist
|
||||
**/node_modules
|
||||
**/__pycache__
|
||||
**/*.egg-info
|
||||
**/*.egg-info
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
b3dccfaeb636599c02effc377cdd8a87d658256c
|
||||
218b6d0546b990fc449c876fb99f44b50c4daa35
|
||||
182580ff6970caed400be178c5b888514b75d7f2
|
||||
8e9d5c1187b0d36da80571ce4c8ba9b3a37b6c46
|
||||
99aac5870e1092b182e6c5f21abcaab6936a4ad1
|
||||
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -2,4 +2,5 @@
|
||||
# Only affects text files and ignores other file types.
|
||||
# For more info see: https://www.aleksandrhovhannisyan.com/blog/crlf-vs-lf-normalizing-line-endings-in-git/
|
||||
* text=auto
|
||||
docker/** text eol=lf
|
||||
docker/** text eol=lf
|
||||
tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
6
.github/CODEOWNERS
vendored
6
.github/CODEOWNERS
vendored
@@ -1,12 +1,12 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant @hipsterusername
|
||||
@@ -22,7 +22,7 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername @jazzhaiku
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein @hipsterusername
|
||||
|
||||
13
.github/workflows/build-container.yml
vendored
13
.github/workflows/build-container.yml
vendored
@@ -76,9 +76,6 @@ jobs:
|
||||
latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
|
||||
suffix=-${{ matrix.gpu-driver }},onlatest=false
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
@@ -100,10 +97,12 @@ jobs:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
build-args: |
|
||||
GPU_DRIVER=${{ matrix.gpu-driver }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
# cache-from: |
|
||||
# type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
# type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
# cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
|
||||
7
.github/workflows/frontend-checks.yml
vendored
7
.github/workflows/frontend-checks.yml
vendored
@@ -44,7 +44,12 @@ jobs:
|
||||
- name: check for changed frontend files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
|
||||
7
.github/workflows/frontend-tests.yml
vendored
7
.github/workflows/frontend-tests.yml
vendored
@@ -44,7 +44,12 @@ jobs:
|
||||
- name: check for changed frontend files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
|
||||
28
.github/workflows/python-checks.yml
vendored
28
.github/workflows/python-checks.yml
vendored
@@ -34,6 +34,9 @@ on:
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
env:
|
||||
# uv requires a venv by default - but for this, we can simply use the system python
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
@@ -43,7 +46,12 @@ jobs:
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
@@ -52,25 +60,19 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff==0.6.0
|
||||
shell: bash
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff check --output-format=github .
|
||||
run: uv tool run ruff@0.11.2 check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff format --check .
|
||||
run: uv tool run ruff@0.11.2 format --check .
|
||||
shell: bash
|
||||
|
||||
40
.github/workflows/python-tests.yml
vendored
40
.github/workflows/python-tests.yml
vendored
@@ -39,24 +39,15 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
- '3.12'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
os: ubuntu-24.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
@@ -70,14 +61,22 @@ jobs:
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
# https://github.com/nschloe/action-cached-lfs-checkout
|
||||
uses: nschloe/action-cached-lfs-checkout@f46300cd8952454b9f0a21a3d133d4bd5684cfc2
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
@@ -86,20 +85,25 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable ".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
|
||||
27
.github/workflows/typegen-checks.yml
vendored
27
.github/workflows/typegen-checks.yml
vendored
@@ -42,24 +42,37 @@ jobs:
|
||||
- name: check for changed files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
# Pinned to the _hash_ for v45.0.9 to prevent supply-chain attacks.
|
||||
# See:
|
||||
# - CVE-2025-30066
|
||||
# - https://www.stepsecurity.io/blog/harden-runner-detection-tj-actions-changed-files-action-is-compromised
|
||||
# - https://github.com/tj-actions/changed-files/issues/2463
|
||||
uses: tj-actions/changed-files@a284dc1814e3fd07f2e34267fc8f81227ed29fb8
|
||||
with:
|
||||
files_yaml: |
|
||||
src:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: '3.11'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
python-version: '3.11'
|
||||
|
||||
- name: install python dependencies
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip3 install --use-pep517 --editable="."
|
||||
env:
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable .
|
||||
|
||||
- name: install frontend dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
@@ -72,7 +85,7 @@ jobs:
|
||||
|
||||
- name: generate schema
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: make frontend-typegen
|
||||
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
shell: bash
|
||||
|
||||
- name: compare files
|
||||
|
||||
@@ -1,62 +1,6 @@
|
||||
# syntax=docker/dockerfile:1.4
|
||||
|
||||
## Builder stage
|
||||
|
||||
FROM library/ubuntu:24.04 AS builder
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt-get install -y \
|
||||
build-essential \
|
||||
git
|
||||
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
|
||||
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV UV_COMPILE_BYTECODE=1
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
ARG TARGETPLATFORM="linux/amd64"
|
||||
# unused but available
|
||||
ARG BUILDPLATFORM
|
||||
|
||||
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
|
||||
RUN mkdir -p ${VIRTUAL_ENV} && \
|
||||
mkdir -p ${INVOKEAI_SRC} && \
|
||||
chmod -R a+w /opt
|
||||
USER ubuntu
|
||||
|
||||
# Install python and create the venv
|
||||
RUN uv python install ${PYTHON_VERSION} && \
|
||||
uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
COPY invokeai ./invokeai
|
||||
COPY pyproject.toml ./
|
||||
|
||||
# Editable mode helps use the same image for development:
|
||||
# the local working copy can be bind-mounted into the image
|
||||
# at path defined by ${INVOKEAI_SRC}
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
|
||||
else \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
|
||||
fi && \
|
||||
uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."
|
||||
|
||||
#### Build the Web UI ------------------------------------
|
||||
#### Web UI ------------------------------------
|
||||
|
||||
FROM docker.io/node:22-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
@@ -70,68 +14,89 @@ RUN --mount=type=cache,target=/pnpm/store \
|
||||
pnpm install --frozen-lockfile
|
||||
RUN npx vite build
|
||||
|
||||
#### Runtime stage ---------------------------------------
|
||||
## Backend ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:24.04 AS runtime
|
||||
FROM library/ubuntu:24.04
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
--mount=type=cache,target=/var/lib/apt \
|
||||
apt update && apt install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
git \
|
||||
gosu \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev
|
||||
|
||||
RUN apt update && apt install -y --no-install-recommends \
|
||||
git \
|
||||
curl \
|
||||
vim \
|
||||
tmux \
|
||||
ncdu \
|
||||
iotop \
|
||||
bzip2 \
|
||||
gosu \
|
||||
magic-wormhole \
|
||||
libglib2.0-0 \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev &&\
|
||||
apt-get clean && apt-get autoclean
|
||||
ENV \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
VIRTUAL_ENV=/opt/venv \
|
||||
INVOKEAI_SRC=/opt/invokeai \
|
||||
PYTHON_VERSION=3.12 \
|
||||
UV_PYTHON=3.12 \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_MANAGED_PYTHON=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_PROJECT_ENVIRONMENT=/opt/venv \
|
||||
UV_INDEX="https://download.pytorch.org/whl/cu124" \
|
||||
INVOKEAI_ROOT=/invokeai \
|
||||
INVOKEAI_HOST=0.0.0.0 \
|
||||
INVOKEAI_PORT=9090 \
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
CONTAINER_UID=${CONTAINER_UID:-1000} \
|
||||
CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV INVOKEAI_ROOT=/invokeai
|
||||
ENV INVOKEAI_HOST=0.0.0.0
|
||||
ENV INVOKEAI_PORT=9090
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
|
||||
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
|
||||
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
ARG GPU_DRIVER=cuda
|
||||
|
||||
# Install `uv` for package management
|
||||
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
|
||||
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
|
||||
USER ubuntu
|
||||
RUN uv python install ${PYTHON_VERSION}
|
||||
USER root
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.6.9 /uv /uvx /bin/
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4
|
||||
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
|
||||
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
# Install python & allow non-root user to use it by traversing the /root dir without read permissions
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv python install ${PYTHON_VERSION} && \
|
||||
# chmod --recursive a+rX /root/.local/share/uv/python
|
||||
chmod 711 /root
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
|
||||
# bind-mount instead of copy to defer adding sources to the image until next layer.
|
||||
#
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
# this is just to get the package manager to recognize that the project exists, without making changes to the docker layer
|
||||
--mount=type=bind,source=invokeai/version,target=invokeai/version \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then UV_INDEX="https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then UV_INDEX="https://download.pytorch.org/whl/rocm6.2"; \
|
||||
fi && \
|
||||
uv sync --frozen
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python -c "from patchmatch import patch_match"
|
||||
|
||||
# Link amdgpu.ids for ROCm builds
|
||||
# contributed by https://github.com/Rubonnek
|
||||
RUN mkdir -p "/opt/amdgpu/share/libdrm" &&\
|
||||
ln -s "/usr/share/libdrm/amdgpu.ids" "/opt/amdgpu/share/libdrm/amdgpu.ids"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
CMD ["invokeai-web"]
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4, does not work with podman
|
||||
COPY --link --from=web-builder /build/dist ${INVOKEAI_SRC}/invokeai/frontend/web/dist
|
||||
|
||||
# add sources last to minimize image changes on code changes
|
||||
COPY invokeai ${INVOKEAI_SRC}/invokeai
|
||||
138
docs/RELEASE.md
138
docs/RELEASE.md
@@ -1,41 +1,50 @@
|
||||
# Release Process
|
||||
|
||||
The app is published in twice, in different build formats.
|
||||
The Invoke application is published as a python package on [PyPI]. This includes both a source distribution and built distribution (a wheel).
|
||||
|
||||
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
|
||||
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
|
||||
Most users install it with the [Launcher](https://github.com/invoke-ai/launcher/), others with `pip`.
|
||||
|
||||
The launcher uses GitHub as the source of truth for available releases.
|
||||
|
||||
## Broad Strokes
|
||||
|
||||
- Merge all changes and bump the version in the codebase.
|
||||
- Tag the release commit.
|
||||
- Wait for the release workflow to complete.
|
||||
- Approve the PyPI publish jobs.
|
||||
- Write GH release notes.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out.
|
||||
|
||||
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
|
||||
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
|
||||
|
||||
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
|
||||
|
||||
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
|
||||
It is triggered on **tag push**, when the tag matches `v*`.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow.
|
||||
Ensure all commits that should be in the release are merged, and you have pulled them locally.
|
||||
|
||||
The release may also be dispatched [manually].
|
||||
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
|
||||
|
||||
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
|
||||
|
||||
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
|
||||
The workflow consists of a number of concurrently-run checks and tests, then two final publish jobs.
|
||||
|
||||
The publish jobs require manual approval and are only run if the other jobs succeed.
|
||||
|
||||
#### `check-version` Job
|
||||
|
||||
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
|
||||
This job ensures that the `invokeai` python package version specifier matches the tag for the release. The version specifier is pulled from the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
@@ -43,62 +52,52 @@ This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
#### Check and Test Jobs
|
||||
|
||||
Next, these jobs run and must pass. They are the same jobs that are run for every PR.
|
||||
|
||||
- **`python-tests`**: runs `pytest` on matrix of platforms
|
||||
- **`python-checks`**: runs `ruff` (format and lint)
|
||||
- **`frontend-tests`**: runs `vitest`
|
||||
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
|
||||
|
||||
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
|
||||
|
||||
> **TODO** We should add an end-to-end test job that generates an image.
|
||||
- **`typegen-checks`**: ensures the frontend and backend types are synced
|
||||
|
||||
#### `build-installer` Job
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
|
||||
|
||||
- **`dist`**: the python distribution, to be published on PyPI
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the legacy install scripts
|
||||
|
||||
You don't need to download either of these files.
|
||||
|
||||
> The legacy install scripts are no longer used, but we haven't updated the workflow to skip building them.
|
||||
|
||||
#### Sanity Check & Smoke Test
|
||||
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval. Time to test the installer.
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval.
|
||||
|
||||
Because the installer pulls from PyPI, and we haven't published to PyPI yet, you will need to install from the wheel:
|
||||
It's possible to test the python package before it gets published to PyPI. We've never had problems with it, so it's not necessary to do this.
|
||||
|
||||
- Download and unzip `dist.zip` and the installer from the **Summary** tab of the workflow
|
||||
- Run the installer script using the `--wheel` CLI arg, pointing at the wheel:
|
||||
But, if you want to be extra-super careful, here's how to test it:
|
||||
|
||||
```sh
|
||||
./install.sh --wheel ../InvokeAI-4.0.0rc6-py3-none-any.whl
|
||||
```
|
||||
|
||||
- Install to a temporary directory so you get the new user experience
|
||||
- Download a model and generate
|
||||
|
||||
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation as if the installer got the wheel from PyPI.
|
||||
- Download the `dist.zip` build artifact from the `build-installer` job
|
||||
- Unzip it and find the wheel file
|
||||
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/) - but instead of installing from PyPI, install from the wheel
|
||||
- Test the app
|
||||
|
||||
##### Something isn't right
|
||||
|
||||
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?).
|
||||
|
||||
Now you can start from the top:
|
||||
|
||||
- Fix the issues and PR the fixes per usual
|
||||
- Get the PR approved and merged per usual
|
||||
- Switch to `main` and pull in the fixes
|
||||
- Run `make tag-release` to move the tag to `HEAD` (which has the fixes) and kick off the release workflow again
|
||||
- Re-do the sanity check
|
||||
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?) and start over.
|
||||
|
||||
#### PyPI Publish Jobs
|
||||
|
||||
The publish jobs will run if any of the previous jobs fail.
|
||||
The publish jobs will not run if any of the previous jobs fail.
|
||||
|
||||
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
|
||||
|
||||
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
|
||||
Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
|
||||
|
||||
- Click the **Review deployments** button
|
||||
- Select the environment (either `testpypi` or `pypi`)
|
||||
- Select the environment (either `testpypi` or `pypi` - typically you select both)
|
||||
- Click **Approve and deploy**
|
||||
|
||||
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
|
||||
@@ -113,46 +112,33 @@ If there are no incidents, contact @hipsterusername or @lstein, who have owner a
|
||||
|
||||
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
|
||||
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release for some reason:
|
||||
|
||||
If approved and successful, you could try out the test release like this:
|
||||
|
||||
```sh
|
||||
# Create a new virtual environment
|
||||
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
|
||||
# Install the distribution from Test PyPI
|
||||
pip install --index-url https://test.pypi.org/simple/ invokeai
|
||||
# Run and test the app
|
||||
invokeai-web
|
||||
# Cleanup
|
||||
deactivate
|
||||
rm -rf ~/.test-invokeai-dist
|
||||
```
|
||||
- Approve this publish job without approving the prod publish
|
||||
- Let it finish
|
||||
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/), making sure to use the Test PyPI index URL: `https://test.pypi.org/simple/`
|
||||
- Test the app
|
||||
|
||||
#### `publish-pypi` Job
|
||||
|
||||
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
|
||||
|
||||
## Publish the GitHub Release with installer
|
||||
It's a good idea to wait to approve and run this job until you have the release notes ready!
|
||||
|
||||
Once the release is published to PyPI, it's time to publish the GitHub release.
|
||||
## Prep and publish the GitHub Release
|
||||
|
||||
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
|
||||
1. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
|
||||
1. Use `scripts/get_external_contributions.py` to get a list of external contributions to shout out in the release notes.
|
||||
1. Upload the zip file created in **`build`** job into the Assets section of the release notes.
|
||||
1. Check **Set as a pre-release** if it's a pre-release.
|
||||
1. Check **Create a discussion for this release**.
|
||||
1. Publish the release.
|
||||
1. Announce the release in Discord.
|
||||
|
||||
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
|
||||
|
||||
## Manual Build
|
||||
|
||||
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
|
||||
|
||||
No checks are run, it just builds.
|
||||
2. The **Generate release notes** button automatically inserts the changelog and new contributors. Make sure to select the correct tags for this release and the last stable release. GH often selects the wrong tags - do this manually.
|
||||
3. Write the release notes, describing important changes. Contributions from community members should be shouted out. Use the GH-generated changelog to see all contributors. If there are Weblate translation updates, open that PR and shout out every person who contributed a translation.
|
||||
4. Check **Set as a pre-release** if it's a pre-release.
|
||||
5. Approve and wait for the `publish-pypi` job to finish if you haven't already.
|
||||
6. Publish the GH release.
|
||||
7. Post the release in Discord in the [releases](https://discord.com/channels/1020123559063990373/1149260708098359327) channel with abbreviated notes. For example:
|
||||
> Invoke v5.7.0 (stable): <https://github.com/invoke-ai/InvokeAI/releases/tag/v5.7.0>
|
||||
>
|
||||
> It's a pretty big one - Form Builder, Metadata Nodes (thanks @SkunkWorxDark!), and much more.
|
||||
8. Right click the message in releases and copy the link to it. Then, post that link in the [new-release-discussion](https://discord.com/channels/1020123559063990373/1149506274971631688) channel. For example:
|
||||
> Invoke v5.7.0 (stable): <https://discord.com/channels/1020123559063990373/1149260708098359327/1344521744916021248>
|
||||
|
||||
## Manual Release
|
||||
|
||||
@@ -160,12 +146,10 @@ The `release` workflow can be dispatched manually. You must dispatch the workflo
|
||||
|
||||
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
|
||||
|
||||
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
|
||||
[PyPI]: https://pypi.org/
|
||||
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
|
||||
[Test PyPI]: https://test.pypi.org/
|
||||
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
|
||||
[ncipollo/release-action]: https://github.com/ncipollo/release-action
|
||||
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
|
||||
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
|
||||
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
|
||||
|
||||
@@ -18,9 +18,19 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
|
||||
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
|
||||
|
||||
3. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
3. This repository uses Git LFS to manage large files. To ensure all assets are downloaded:
|
||||
- Install git-lfs → [Download here](https://git-lfs.com/)
|
||||
- Enable automatic LFS fetching for this repository:
|
||||
```shell
|
||||
git config lfs.fetchinclude "*"
|
||||
```
|
||||
- Fetch files from LFS (only needs to be done once; subsequent `git pull` will fetch changes automatically):
|
||||
```
|
||||
git lfs pull
|
||||
```
|
||||
4. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
|
||||
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
5. Follow the [manual install][manual install link] guide, with some modifications to the install command:
|
||||
|
||||
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
|
||||
|
||||
@@ -31,22 +41,22 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
With the modifications made, the install command should look something like this:
|
||||
|
||||
```sh
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
|
||||
```
|
||||
|
||||
5. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
|
||||
|
||||
This is because the UI build is not distributed with the source code. You need to build it manually. End the running server instance.
|
||||
|
||||
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
|
||||
|
||||
6. Install the frontend dev toolchain:
|
||||
7. Install the frontend dev toolchain:
|
||||
|
||||
- [`nodejs`](https://nodejs.org/) (v20+)
|
||||
|
||||
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
|
||||
|
||||
7. Do a production build of the frontend:
|
||||
8. Do a production build of the frontend:
|
||||
|
||||
```sh
|
||||
cd <PATH_TO_INVOKEAI_REPO>/invokeai/frontend/web
|
||||
@@ -54,7 +64,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
|
||||
pnpm build
|
||||
```
|
||||
|
||||
8. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
|
||||
9. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
|
||||
|
||||
## Updating the UI
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ It is possible to fine-tune the settings for best performance or if you still ge
|
||||
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
|
||||
|
||||
- Partial model loading (`enable_partial_loading`)
|
||||
- PyTorch CUDA allocator config (`pytorch_cuda_alloc_conf`)
|
||||
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
|
||||
- Working memory (`device_working_mem_gb`)
|
||||
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
|
||||
@@ -51,6 +52,16 @@ As described above, you can enable partial model loading by adding this line to
|
||||
enable_partial_loading: true
|
||||
```
|
||||
|
||||
### PyTorch CUDA allocator config
|
||||
|
||||
The PyTorch CUDA allocator's behavior can be configured using the `pytorch_cuda_alloc_conf` config. Tuning the allocator configuration can help to reduce the peak reserved VRAM. The optimal configuration is dependent on many factors (e.g. device type, VRAM, CUDA driver version, etc.), but switching from PyTorch's native allocator to using CUDA's built-in allocator works well on many systems. To try this, add the following line to your `invokeai.yaml` file:
|
||||
|
||||
```yaml
|
||||
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
|
||||
```
|
||||
|
||||
A more complete explanation of the available configuration options is [here](https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf).
|
||||
|
||||
### Dynamic RAM and VRAM cache sizes
|
||||
|
||||
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
|
||||
@@ -75,24 +86,26 @@ But, if your GPU has enough VRAM to hold models fully, you might get a perf boos
|
||||
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
|
||||
# might be a good value to achieve aggressive model caching.
|
||||
max_cache_ram_gb: 28
|
||||
|
||||
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
|
||||
# consideration the VRAM used by other processes).
|
||||
# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
|
||||
# `device_working_mem_gb`.
|
||||
# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
|
||||
# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
|
||||
# auxiliary models, 18GB might be a good value.
|
||||
max_cache_vram_gb: 18
|
||||
# You can override the default value by setting `max_cache_vram_gb`.
|
||||
# CAUTION: Most users should not manually set this value. See warning below.
|
||||
max_cache_vram_gb: 16
|
||||
```
|
||||
|
||||
!!! tip "Max safe value for `max_cache_vram_gb`"
|
||||
!!! warning "Max safe value for `max_cache_vram_gb`"
|
||||
|
||||
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
Most users should not manually configure the `max_cache_vram_gb`. This configuration value takes precedence over the `device_working_mem_gb` and any operations that explicitly reserve additional working memory (e.g. VAE decode). As such, manually configuring it increases the likelihood of encountering out-of-memory errors.
|
||||
|
||||
For users who wish to configure `max_cache_vram_gb`, the max safe value can be determined by subtracting `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
|
||||
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
|
||||
|
||||
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
|
||||
|
||||
Most users who override `max_cache_vram_gb` are doing so because they wish to use significantly less VRAM, and should be setting `max_cache_vram_gb` to a value significantly less than the 'max safe value'.
|
||||
|
||||
### Working memory
|
||||
|
||||
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
|
||||
|
||||
@@ -43,10 +43,10 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
3. Create a virtual environment in that directory:
|
||||
|
||||
```sh
|
||||
uv venv --relocatable --prompt invoke --python 3.11 --python-preference only-managed .venv
|
||||
uv venv --relocatable --prompt invoke --python 3.12 --python-preference only-managed .venv
|
||||
```
|
||||
|
||||
This command creates a portable virtual environment at `.venv` complete with a portable python 3.11. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
|
||||
This command creates a portable virtual environment at `.venv` complete with a portable python 3.12. It doesn't matter if your system has no python installed, or has a different version - `uv` will handle everything.
|
||||
|
||||
4. Activate the virtual environment:
|
||||
|
||||
@@ -88,13 +88,13 @@ The following commands vary depending on the version of Invoke being installed a
|
||||
8. Install the `invokeai` package. Substitute the package specifier and version.
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --force-reinstall
|
||||
```
|
||||
|
||||
If you determined you needed to use a `PyPI` index URL in the previous step, you'll need to add `--index=<INDEX_URL>` like this:
|
||||
|
||||
```sh
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.11 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
|
||||
uv pip install <PACKAGE_SPECIFIER>==<VERSION> --python 3.12 --python-preference only-managed --index=<INDEX_URL> --force-reinstall
|
||||
```
|
||||
|
||||
9. Deactivate and reactivate your venv so that the invokeai-specific commands become available in the environment:
|
||||
|
||||
@@ -41,7 +41,7 @@ The requirements below are rough guidelines for best performance. GPUs with less
|
||||
|
||||
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
|
||||
|
||||
Invoke requires python 3.10 or 3.11. If you don't already have one of these versions installed, we suggest installing 3.11, as it will be supported for longer.
|
||||
Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
|
||||
|
||||
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
|
||||
|
||||
@@ -49,19 +49,19 @@ Check that your system has an up-to-date Python installed by running `python3 --
|
||||
|
||||
=== "Windows"
|
||||
|
||||
- Install python 3.11 with [an official installer].
|
||||
- Install python with [an official installer].
|
||||
- The installer includes an option to add python to your PATH. Be sure to enable this. If you missed it, re-run the installer, choose to modify an existing installation, and tick that checkbox.
|
||||
- You may need to install [Microsoft Visual C++ Redistributable].
|
||||
|
||||
=== "macOS"
|
||||
|
||||
- Install python 3.11 with [an official installer].
|
||||
- Install python with [an official installer].
|
||||
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
|
||||
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
|
||||
|
||||
=== "Linux"
|
||||
|
||||
- Installing python varies depending on your system. On Ubuntu, you can use the [deadsnakes PPA](https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa).
|
||||
- Installing python varies depending on your system. We recommend [using `uv` to manage your python installation](https://docs.astral.sh/uv/concepts/python-versions/#installing-a-python-version).
|
||||
- You'll need to install `libglib2.0-0` and `libgl1-mesa-glx` for OpenCV to work. For example, on a Debian system: `sudo apt update && sudo apt install -y libglib2.0-0 libgl1-mesa-glx`
|
||||
|
||||
## Drivers
|
||||
|
||||
@@ -36,7 +36,14 @@ from invokeai.app.services.style_preset_images.style_preset_images_disk import S
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -83,6 +90,7 @@ class ApiDependencies:
|
||||
|
||||
model_images_folder = config.models_path
|
||||
style_presets_folder = config.style_presets_path
|
||||
workflow_thumbnails_folder = config.workflow_thumbnails_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
@@ -99,10 +107,25 @@ class ApiDependencies:
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||
ObjectSerializerDisk[torch.Tensor](
|
||||
output_folder / "tensors",
|
||||
safe_globals=[torch.Tensor],
|
||||
ephemeral=True,
|
||||
),
|
||||
max_cache_size=0,
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
ObjectSerializerDisk[ConditioningFieldData](
|
||||
output_folder / "conditioning",
|
||||
safe_globals=[
|
||||
ConditioningFieldData,
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
FLUXConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
],
|
||||
ephemeral=True,
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
@@ -120,6 +143,7 @@ class ApiDependencies:
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -147,6 +171,7 @@ class ApiDependencies:
|
||||
conditioning=conditioning,
|
||||
style_preset_records=style_preset_records,
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
workflow_thumbnails=workflow_thumbnails,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
124
invokeai/app/api/extract_metadata_from_image.py
Normal file
124
invokeai/app/api/extract_metadata_from_image.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutIDValidator
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedMetadata:
|
||||
invokeai_metadata: str | None
|
||||
invokeai_workflow: str | None
|
||||
invokeai_graph: str | None
|
||||
|
||||
|
||||
def extract_metadata_from_image(
|
||||
pil_image: Image.Image,
|
||||
invokeai_metadata_override: str | None,
|
||||
invokeai_workflow_override: str | None,
|
||||
invokeai_graph_override: str | None,
|
||||
logger: logging.Logger,
|
||||
) -> ExtractedMetadata:
|
||||
"""
|
||||
Extracts the "invokeai_metadata", "invokeai_workflow", and "invokeai_graph" data embedded in the PIL Image.
|
||||
|
||||
These items are stored as stringified JSON in the image file's metadata, so we need to do some parsing to validate
|
||||
them. Once parsed, the values are returned as they came (as strings), or None if they are not present or invalid.
|
||||
|
||||
In some situations, we may prefer to override the values extracted from the image file with some other values.
|
||||
|
||||
For example, when uploading an image via API, the client can optionally provide the metadata directly in the request,
|
||||
as opposed to embedding it in the image file. In this case, the client-provided metadata will be used instead of the
|
||||
metadata embedded in the image file.
|
||||
|
||||
Args:
|
||||
pil_image: The PIL Image object.
|
||||
invokeai_metadata_override: The metadata override provided by the client.
|
||||
invokeai_workflow_override: The workflow override provided by the client.
|
||||
invokeai_graph_override: The graph override provided by the client.
|
||||
logger: The logger to use for debug logging.
|
||||
|
||||
Returns:
|
||||
ExtractedMetadata: The extracted metadata, workflow, and graph.
|
||||
"""
|
||||
|
||||
# The fallback value for metadata is None.
|
||||
stringified_metadata: str | None = None
|
||||
|
||||
# Use the metadata override if provided, else attempt to extract it from the image file.
|
||||
metadata_raw = invokeai_metadata_override or pil_image.info.get("invokeai_metadata", None)
|
||||
|
||||
# If the metadata is present in the image file, we will attempt to parse it as JSON. When we create images,
|
||||
# we always store metadata as a stringified JSON dict. So, we expect it to be a string here.
|
||||
if isinstance(metadata_raw, str):
|
||||
try:
|
||||
# Must be a JSON string
|
||||
metadata_parsed = json.loads(metadata_raw)
|
||||
# Must be a dict
|
||||
if isinstance(metadata_parsed, dict):
|
||||
# Looks good, overwrite the fallback value
|
||||
stringified_metadata = metadata_raw
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse metadata for uploaded image, {e}")
|
||||
pass
|
||||
|
||||
# We expect the workflow, if embedded in the image, to be a JSON-stringified WorkflowWithoutID. We will store it
|
||||
# as a string.
|
||||
workflow_raw: str | None = invokeai_workflow_override or pil_image.info.get("invokeai_workflow", None)
|
||||
|
||||
# The fallback value for workflow is None.
|
||||
stringified_workflow: str | None = None
|
||||
|
||||
# If the workflow is present in the image file, we will attempt to parse it as JSON. When we create images, we
|
||||
# always store workflows as a stringified JSON WorkflowWithoutID. So, we expect it to be a string here.
|
||||
if isinstance(workflow_raw, str):
|
||||
try:
|
||||
# Validate the workflow JSON before storing it
|
||||
WorkflowWithoutIDValidator.validate_json(workflow_raw)
|
||||
# Looks good, overwrite the fallback value
|
||||
stringified_workflow = workflow_raw
|
||||
except Exception:
|
||||
logger.debug("Failed to parse workflow for uploaded image")
|
||||
pass
|
||||
|
||||
# We expect the workflow, if embedded in the image, to be a JSON-stringified Graph. We will store it as a
|
||||
# string.
|
||||
graph_raw: str | None = invokeai_graph_override or pil_image.info.get("invokeai_graph", None)
|
||||
|
||||
# The fallback value for graph is None.
|
||||
stringified_graph: str | None = None
|
||||
|
||||
# If the graph is present in the image file, we will attempt to parse it as JSON. When we create images, we
|
||||
# always store graphs as a stringified JSON Graph. So, we expect it to be a string here.
|
||||
if isinstance(graph_raw, str):
|
||||
try:
|
||||
# TODO(psyche): Due to pydantic's handling of None values, it is possible for the graph to fail validation,
|
||||
# even if it is a direct dump of a valid graph. Node fields in the graph are allowed to have be unset if
|
||||
# they have incoming connections, but something about the ser/de process cannot adequately handle this.
|
||||
#
|
||||
# In lieu of fixing the graph validation, we will just do a simple check here to see if the graph is dict
|
||||
# with the correct keys. This is not a perfect solution, but it should be good enough for now.
|
||||
|
||||
# FIX ME: Validate the graph JSON before storing it
|
||||
# Graph.model_validate_json(graph_raw)
|
||||
|
||||
# Crappy workaround to validate JSON
|
||||
graph_parsed = json.loads(graph_raw)
|
||||
if not isinstance(graph_parsed, dict):
|
||||
raise ValueError("Not a dict")
|
||||
if not isinstance(graph_parsed.get("nodes", None), dict):
|
||||
raise ValueError("'nodes' is not a dict")
|
||||
if not isinstance(graph_parsed.get("edges", None), list):
|
||||
raise ValueError("'edges' is not a list")
|
||||
|
||||
# Looks good, overwrite the fallback value
|
||||
stringified_graph = graph_raw
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse graph for uploaded image, {e}")
|
||||
pass
|
||||
|
||||
return ExtractedMetadata(
|
||||
invokeai_metadata=stringified_metadata, invokeai_workflow=stringified_workflow, invokeai_graph=stringified_graph
|
||||
)
|
||||
@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.util.logging import logging
|
||||
@@ -99,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
|
||||
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
async def get_config_() -> AppConfig:
|
||||
infill_methods = ["lama", "tile", "cv2", "color"] # TODO: add mosaic back
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
@@ -121,6 +122,21 @@ async def get_config() -> AppConfig:
|
||||
)
|
||||
|
||||
|
||||
class InvokeAIAppConfigWithSetFields(BaseModel):
|
||||
"""InvokeAI App Config with model fields set"""
|
||||
|
||||
set_fields: set[str] = Field(description="The set fields")
|
||||
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
|
||||
)
|
||||
async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
|
||||
config = get_config()
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
@@ -146,7 +146,7 @@ async def list_boards(
|
||||
response_model=list[str],
|
||||
)
|
||||
async def list_all_board_image_names(
|
||||
board_id: str | Literal["none"] = Path(description="The id of the board"),
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
|
||||
) -> list[str]:
|
||||
|
||||
@@ -6,9 +6,10 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
@@ -45,18 +46,16 @@ async def upload_image(
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
metadata: Optional[JsonValue] = Body(
|
||||
default=None, description="The metadata to associate with the image", embed=True
|
||||
metadata: Optional[str] = Body(
|
||||
default=None,
|
||||
description="The metadata to associate with the image, must be a stringified JSON dict",
|
||||
embed=True,
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
_metadata = None
|
||||
_workflow = None
|
||||
_graph = None
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
@@ -67,30 +66,13 @@ async def upload_image(
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
# TODO: retain non-invokeai metadata on upload?
|
||||
# attempt to parse metadata from image
|
||||
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
|
||||
if isinstance(metadata_raw, str):
|
||||
_metadata = metadata_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to parse workflow from image
|
||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||
if isinstance(workflow_raw, str):
|
||||
_workflow = workflow_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to extract graph from image
|
||||
graph_raw = pil_image.info.get("invokeai_graph", None)
|
||||
if isinstance(graph_raw, str):
|
||||
_graph = graph_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
|
||||
pass
|
||||
extracted_metadata = extract_metadata_from_image(
|
||||
pil_image=pil_image,
|
||||
invokeai_metadata_override=metadata,
|
||||
invokeai_workflow_override=None,
|
||||
invokeai_graph_override=None,
|
||||
logger=ApiDependencies.invoker.services.logger,
|
||||
)
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
@@ -99,9 +81,9 @@ async def upload_image(
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
board_id=board_id,
|
||||
metadata=_metadata,
|
||||
workflow=_workflow,
|
||||
graph=_graph,
|
||||
metadata=extracted_metadata.invokeai_metadata,
|
||||
workflow=extracted_metadata.invokeai_workflow,
|
||||
graph=extracted_metadata.invokeai_graph,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
@@ -114,6 +96,22 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
class ImageUploadEntry(BaseModel):
|
||||
image_dto: ImageDTO = Body(description="The image DTO")
|
||||
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
|
||||
|
||||
|
||||
@images_router.post("/", operation_id="create_image_upload_entry")
|
||||
async def create_image_upload_entry(
|
||||
width: int = Body(description="The width of the image"),
|
||||
height: int = Body(description="The height of the image"),
|
||||
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
|
||||
) -> ImageUploadEntry:
|
||||
"""Uploads an image from a URL, not implemented"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
|
||||
@@ -28,12 +28,10 @@ from invokeai.app.services.model_records import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
|
||||
@@ -48,7 +48,9 @@ async def enqueue_batch(
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||
from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
@@ -10,11 +14,14 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowCategory,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordListItemWithThumbnailDTO,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowRecordWithThumbnailDTO,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import WorkflowThumbnailFileNotFoundException
|
||||
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
|
||||
|
||||
@@ -22,15 +29,17 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
"/i/{workflow_id}",
|
||||
operation_id="get_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
200: {"model": WorkflowRecordWithThumbnailDTO},
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowRecordDTO:
|
||||
) -> WorkflowRecordWithThumbnailDTO:
|
||||
"""Gets a workflow"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
@@ -57,6 +66,11 @@ async def delete_workflow(
|
||||
workflow_id: str = Path(description="The workflow to delete"),
|
||||
) -> None:
|
||||
"""Deletes a workflow"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except WorkflowThumbnailFileNotFoundException:
|
||||
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
|
||||
pass
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
|
||||
|
||||
|
||||
@@ -78,7 +92,7 @@ async def create_workflow(
|
||||
"/",
|
||||
operation_id="list_workflows",
|
||||
responses={
|
||||
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
|
||||
200: {"model": PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]},
|
||||
},
|
||||
)
|
||||
async def list_workflows(
|
||||
@@ -88,10 +102,158 @@ async def list_workflows(
|
||||
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
|
||||
),
|
||||
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
|
||||
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories of workflow to get"),
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
return ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by, direction=direction, page=page, per_page=per_page, query=query, category=category
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by,
|
||||
direction=direction,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
query=query,
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
WorkflowRecordListItemWithThumbnailDTO(
|
||||
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
|
||||
**workflow.model_dump(),
|
||||
)
|
||||
)
|
||||
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
|
||||
items=workflows_with_thumbnails,
|
||||
total=workflows.total,
|
||||
page=workflows.page,
|
||||
pages=workflows.pages,
|
||||
per_page=workflows.per_page,
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.put(
|
||||
"/i/{workflow_id}/thumbnail",
|
||||
operation_id="set_workflow_thumbnail",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def set_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
image: UploadFile = File(description="The image file to upload"),
|
||||
):
|
||||
"""Sets a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.save(workflow_id, pil_image)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@workflows_router.delete(
|
||||
"/i/{workflow_id}/thumbnail",
|
||||
operation_id="delete_workflow_thumbnail",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def delete_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
):
|
||||
"""Removes a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
"/i/{workflow_id}/thumbnail",
|
||||
operation_id="get_workflow_thumbnail",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The workflow thumbnail was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The workflow thumbnail could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a workflow's thumbnail image"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=workflow_id + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
|
||||
async def get_counts_by_tag(
|
||||
tags: list[str] = Query(description="The tags to get counts for"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by tag"""
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
|
||||
tags=tags, categories=categories, has_been_opened=has_been_opened
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
|
||||
async def counts_by_category(
|
||||
categories: list[WorkflowCategory] = Query(description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by category"""
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
|
||||
categories=categories, has_been_opened=has_been_opened
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.put(
|
||||
"/i/{workflow_id}/opened_at",
|
||||
operation_id="update_opened_at",
|
||||
)
|
||||
async def update_opened_at(
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
) -> None:
|
||||
"""Updates the opened_at field of a workflow"""
|
||||
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
@@ -15,11 +11,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
@@ -38,31 +30,13 @@ from invokeai.app.api.routers import (
|
||||
from invokeai.app.api.sockets import SocketIO
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
|
||||
# the correct port when the server starts in the lifespan handler.
|
||||
port = app_config.port
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -71,7 +45,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
|
||||
proto = "https" if app_config.ssl_certfile else "http"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
|
||||
|
||||
# Logging this way ignores the logger's log level and _always_ logs the message
|
||||
record = logger.makeRecord(
|
||||
@@ -186,73 +160,3 @@ except RuntimeError:
|
||||
app.mount(
|
||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||
) # docs favicon is in here
|
||||
|
||||
|
||||
def check_cudnn(logger: logging.Logger) -> None:
|
||||
"""Check for cuDNN issues that could be causing degraded performance."""
|
||||
if torch.backends.cudnn.is_available():
|
||||
try:
|
||||
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
|
||||
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
logger.info(f"cuDNN version: {cudnn_version}")
|
||||
except RuntimeError as e:
|
||||
logger.warning(
|
||||
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
|
||||
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
|
||||
f"system. Full error message:\n{e}"
|
||||
)
|
||||
|
||||
|
||||
def invoke_api() -> None:
|
||||
def find_port(port: int) -> int:
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
# https://github.com/WaylonWalker
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
if s.connect_ex(("localhost", port)) == 0:
|
||||
return find_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
global port
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
port=port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level_network,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
|
||||
uvicorn_logger.handlers.clear()
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -1,33 +1,5 @@
|
||||
import shutil
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
custom_nodes_path = Path(get_config().custom_nodes_path)
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
|
||||
custom_nodes_readme_path = str(custom_nodes_path / "README.md")
|
||||
|
||||
# copy our custom nodes __init__.py to the custom nodes directory
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
|
||||
|
||||
# set the same permissions as the destination directory, in case our source is read-only,
|
||||
# so that the files are user-writable
|
||||
for p in custom_nodes_path.glob("**/*"):
|
||||
p.chmod(custom_nodes_path.stat().st_mode)
|
||||
|
||||
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
|
||||
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError(f"Could not load custom nodes from {custom_nodes_init_path}")
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# add core nodes to __all__
|
||||
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
|
||||
__all__ = [f.stem for f in python_files] # type: ignore
|
||||
|
||||
@@ -8,6 +8,7 @@ import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -27,7 +28,6 @@ import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
@@ -44,8 +44,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
|
||||
|
||||
|
||||
class InvalidVersionError(ValueError):
|
||||
pass
|
||||
@@ -102,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
@@ -175,66 +142,11 @@ class BaseInvocation(ABC, BaseModel):
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def invalidate_typeadapter(cls) -> None:
|
||||
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
|
||||
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
|
||||
the updated allowlist and denylist."""
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_annotation(cls) -> BaseInvocationOutput:
|
||||
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
|
||||
@@ -337,6 +249,105 @@ class BaseInvocation(ABC, BaseModel):
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
class InvocationRegistry:
|
||||
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
|
||||
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls.invalidate_invocation_typeadapter()
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation types.
|
||||
|
||||
This is used to parse serialized invocations into the correct invocation class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_invocation_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation type adapter."""
|
||||
cls.get_invocation_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[type[BaseInvocation]] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in cls.get_invocation_classes()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in cls.get_invocation_classes())
|
||||
|
||||
@classmethod
|
||||
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
|
||||
"""Gets the invocation class for a given invocation type."""
|
||||
return cls.get_invocations_map().get(invocation_type)
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls.invalidate_output_typeadapter()
|
||||
|
||||
@classmethod
|
||||
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
|
||||
|
||||
This is used to parse serialized invocation outputs into the correct invocation output class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_output_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation output type adapter."""
|
||||
cls.get_output_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in cls.get_output_classes())
|
||||
|
||||
|
||||
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
@@ -414,7 +425,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
|
||||
|
||||
ui_type = field.json_schema_extra.get("ui_type", None)
|
||||
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
|
||||
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
|
||||
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
|
||||
field.json_schema_extra.pop("ui_type")
|
||||
return None
|
||||
|
||||
@@ -446,8 +457,27 @@ def invocation(
|
||||
if re.compile(r"^\S+$").match(invocation_type) is None:
|
||||
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
||||
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
raise ValueError(f'Invocation type "{invocation_type}" already exists')
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
node_pack = cls.__module__.split(".")[0]
|
||||
|
||||
# Handle the case where an existing node is being clobbered by the one we are registering
|
||||
if invocation_type in InvocationRegistry.get_invocation_types():
|
||||
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
|
||||
# This should always be true - we just checked if the invocation type was in the set
|
||||
assert clobbered_invocation is not None
|
||||
|
||||
clobbered_node_pack = clobbered_invocation.UIConfig.node_pack
|
||||
|
||||
if clobbered_node_pack == "invokeai":
|
||||
# The node being clobbered is a core node
|
||||
raise ValueError(
|
||||
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a core node with the same type already exists'
|
||||
)
|
||||
else:
|
||||
# The node being clobbered is a custom node
|
||||
raise ValueError(
|
||||
f'Cannot load node "{invocation_type}" from node pack "{node_pack}" - a node with the same type already exists in node pack "{clobbered_node_pack}"'
|
||||
)
|
||||
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
@@ -457,8 +487,7 @@ def invocation(
|
||||
uiconfig["tags"] = tags
|
||||
uiconfig["category"] = category
|
||||
uiconfig["classification"] = classification
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||
uiconfig["node_pack"] = node_pack
|
||||
|
||||
if version is not None:
|
||||
try:
|
||||
@@ -518,8 +547,7 @@ def invocation(
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||
BaseInvocation.register_invocation(cls) # type: ignore
|
||||
InvocationRegistry.register_invocation(cls)
|
||||
|
||||
return cls
|
||||
|
||||
@@ -544,7 +572,7 @@ def invocation_output(
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
if output_type in InvocationRegistry.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
@@ -565,7 +593,7 @@ def invocation_output(
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||
InvocationRegistry.register_output(cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
@@ -40,10 +40,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"compel",
|
||||
title="Prompt",
|
||||
title="Prompt - SD1.5",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -233,10 +233,10 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
@invocation(
|
||||
"sdxl_compel_prompt",
|
||||
title="SDXL Prompt",
|
||||
title="Prompt - SDXL",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -327,10 +327,10 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
|
||||
@invocation(
|
||||
"sdxl_refiner_compel_prompt",
|
||||
title="SDXL Refiner Prompt",
|
||||
title="Prompt - SDXL Refiner",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.1.2",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -376,10 +376,10 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"clip_skip",
|
||||
title="CLIP Skip",
|
||||
title="Apply CLIP Skip - SD1.5, SDXL",
|
||||
tags=["clipskip", "clip", "skip"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.1.1",
|
||||
)
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
@@ -513,7 +513,7 @@ def log_tokenization_for_text(
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"\n>> [TOKENLOG] Tokens {display_label or ''} ({usedTokens}):")
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
|
||||
@@ -87,7 +87,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
|
||||
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.model import UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import MainConfigBase
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""
|
||||
Invoke-managed custom node loader. See README.md for more information.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
loaded_count = 0
|
||||
|
||||
|
||||
for d in Path(__file__).parent.iterdir():
|
||||
# skip files
|
||||
if not d.is_dir():
|
||||
continue
|
||||
|
||||
# skip hidden directories
|
||||
if d.name.startswith("_") or d.name.startswith("."):
|
||||
continue
|
||||
|
||||
# skip directories without an `__init__.py`
|
||||
init = d / "__init__.py"
|
||||
if not init.exists():
|
||||
continue
|
||||
|
||||
module_name = init.parent.stem
|
||||
|
||||
# skip if already imported
|
||||
if module_name in globals():
|
||||
continue
|
||||
|
||||
# load the module, appending adding a suffix to identify it as a custom node pack
|
||||
spec = spec_from_file_location(module_name, init.absolute())
|
||||
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warn(f"Could not load {init}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loading node pack {module_name}")
|
||||
|
||||
try:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
loaded_count += 1
|
||||
except Exception:
|
||||
full_error = traceback.format_exc()
|
||||
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
|
||||
|
||||
del init, module_name
|
||||
|
||||
if loaded_count > 0:
|
||||
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")
|
||||
@@ -39,8 +39,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -127,10 +127,10 @@ def get_scheduler(
|
||||
|
||||
@invocation(
|
||||
"denoise_latents",
|
||||
title="Denoise Latents",
|
||||
title="Denoise - SD1.5, SDXL",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.5.3",
|
||||
version="1.5.4",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
|
||||
@@ -57,6 +57,9 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
SigLipModel = "SigLipModelField"
|
||||
FluxReduxModel = "FluxReduxModelField"
|
||||
LlavaOnevisionModel = "LLaVAModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -152,6 +155,7 @@ class FieldDescriptions:
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
spandrel_image_to_image_model = "Image-to-Image model"
|
||||
vllm_model = "VLLM model"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
@@ -201,6 +205,9 @@ class FieldDescriptions:
|
||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
|
||||
flux_redux_conditioning = "FLUX Redux conditioning tensor"
|
||||
vllm_model = "The VLLM model to use"
|
||||
flux_fill_conditioning = "FLUX Fill conditioning tensor"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
@@ -259,6 +266,24 @@ class FluxConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class FluxReduxConditioningField(BaseModel):
|
||||
"""A FLUX Redux conditioning tensor primitive value"""
|
||||
|
||||
conditioning: TensorField = Field(description="The Redux image conditioning tensor.")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class FluxFillConditioningField(BaseModel):
|
||||
"""A FLUX Fill conditioning field."""
|
||||
|
||||
image: ImageField = Field(description="The FLUX Fill reference image.")
|
||||
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -21,11 +20,10 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_control_lora_loader",
|
||||
title="Flux Control LoRA",
|
||||
title="Control LoRA - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.1.1",
|
||||
)
|
||||
class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
"""LoRA model and Image to use with FLUX transformer generation."""
|
||||
|
||||
@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -52,7 +51,6 @@ class FluxControlNetOutput(BaseInvocationOutput):
|
||||
tags=["controlnet", "flux"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlNetInvocation(BaseInvocation):
|
||||
"""Collect FLUX ControlNet info to pass to other nodes."""
|
||||
|
||||
@@ -10,11 +10,13 @@ from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxFillConditioningField,
|
||||
FluxReduxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
@@ -46,8 +48,8 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -61,8 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.2.2",
|
||||
classification=Classification.Prototype,
|
||||
version="3.3.0",
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
@@ -103,6 +104,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
redux_conditioning: FluxReduxConditioningField | list[FluxReduxConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Redux conditioning tensor.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
fill_conditioning: FluxFillConditioningField | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Fill conditioning.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
cfg_scale_start_step: int = InputField(
|
||||
default=0,
|
||||
@@ -190,11 +201,23 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
redux_conditionings: list[FluxReduxConditioning] = self._load_redux_conditioning(
|
||||
context=context,
|
||||
redux_cond_field=self.redux_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
)
|
||||
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
|
||||
pos_text_conditionings, img_seq_len=packed_h * packed_w
|
||||
text_conditioning=pos_text_conditionings,
|
||||
redux_conditioning=redux_conditionings,
|
||||
img_seq_len=packed_h * packed_w,
|
||||
)
|
||||
neg_regional_prompting_extension = (
|
||||
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
|
||||
RegionalPromptingExtension.from_text_conditioning(
|
||||
text_conditioning=neg_text_conditionings, redux_conditioning=[], img_seq_len=packed_h * packed_w
|
||||
)
|
||||
if neg_text_conditionings
|
||||
else None
|
||||
)
|
||||
@@ -243,8 +266,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if is_schnell and self.control_lora:
|
||||
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
|
||||
|
||||
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
|
||||
img_cond: torch.Tensor | None = None
|
||||
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
|
||||
if is_flux_fill:
|
||||
img_cond = self._prep_flux_fill_img_cond(
|
||||
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
|
||||
)
|
||||
else:
|
||||
if self.fill_conditioning is not None:
|
||||
raise ValueError("fill_conditioning was provided, but the model is not a FLUX Fill model.")
|
||||
|
||||
if self.control_lora is not None:
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
@@ -253,7 +287,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
img_cond = pack(img_cond) if img_cond is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
@@ -400,6 +433,42 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return text_conditionings
|
||||
|
||||
def _load_redux_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
redux_cond_field: FluxReduxConditioningField | list[FluxReduxConditioningField] | None,
|
||||
packed_height: int,
|
||||
packed_width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> list[FluxReduxConditioning]:
|
||||
# Normalize to a list of FluxReduxConditioningFields.
|
||||
if redux_cond_field is None:
|
||||
return []
|
||||
|
||||
redux_cond_list = (
|
||||
[redux_cond_field] if isinstance(redux_cond_field, FluxReduxConditioningField) else redux_cond_field
|
||||
)
|
||||
|
||||
redux_conditionings: list[FluxReduxConditioning] = []
|
||||
for redux_cond_field in redux_cond_list:
|
||||
# Load the Redux conditioning tensor.
|
||||
redux_cond_data = context.tensors.load(redux_cond_field.conditioning.tensor_name)
|
||||
redux_cond_data.to(device=device, dtype=dtype)
|
||||
|
||||
# Load the mask, if provided.
|
||||
mask: Optional[torch.Tensor] = None
|
||||
if redux_cond_field.mask is not None:
|
||||
mask = context.tensors.load(redux_cond_field.mask.tensor_name)
|
||||
mask = mask.to(device=device)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, packed_height, packed_width, dtype, device
|
||||
)
|
||||
|
||||
redux_conditionings.append(FluxReduxConditioning(redux_embeddings=redux_cond_data, mask=mask))
|
||||
|
||||
return redux_conditionings
|
||||
|
||||
@classmethod
|
||||
def prep_cfg_scale(
|
||||
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
|
||||
@@ -610,7 +679,70 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
return pack(img_cond)
|
||||
|
||||
def _prep_flux_fill_img_cond(
|
||||
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Prepare the FLUX Fill conditioning. This method should be called iff the model is a FLUX Fill model.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
|
||||
"""
|
||||
# Validate inputs.
|
||||
if self.fill_conditioning is None:
|
||||
raise ValueError("A FLUX Fill model is being used without fill_conditioning.")
|
||||
# TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
|
||||
if self.controlnet_vae is None:
|
||||
raise ValueError("A FLUX Fill model is being used without controlnet_vae.")
|
||||
if self.control_lora is not None:
|
||||
raise ValueError(
|
||||
"A FLUX Fill model is being used, but a control_lora was provided. Control LoRAs are not compatible with FLUX Fill models."
|
||||
)
|
||||
|
||||
# Log input warnings related to FLUX Fill usage.
|
||||
if self.denoise_mask is not None:
|
||||
context.logger.warning(
|
||||
"Both fill_conditioning and a denoise_mask were provided. You probably meant to use one or the other."
|
||||
)
|
||||
if self.guidance < 25.0:
|
||||
context.logger.warning("A guidance value of ~30.0 is recommended for FLUX Fill models.")
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
|
||||
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
|
||||
cond_img = np.array(cond_img)
|
||||
cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
|
||||
cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
|
||||
cond_img = cond_img.to(device=device, dtype=dtype)
|
||||
|
||||
# Load the mask and resize it to the target image size.
|
||||
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
|
||||
# We expect mask to be a bool tensor with shape [1, H, W].
|
||||
assert mask.dtype == torch.bool
|
||||
assert mask.dim() == 3
|
||||
assert mask.shape[0] == 1
|
||||
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
|
||||
|
||||
# Prepare image conditioning.
|
||||
cond_img = cond_img * (1 - mask)
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
|
||||
cond_img = pack(cond_img)
|
||||
|
||||
# Prepare mask conditioning.
|
||||
mask = mask[:, 0, :, :]
|
||||
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
|
||||
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
|
||||
mask = pack(mask)
|
||||
|
||||
# Merge image and mask conditioning.
|
||||
img_cond = torch.cat((cond_img, mask), dim=-1)
|
||||
return img_cond
|
||||
|
||||
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
|
||||
if self.ip_adapter is None:
|
||||
|
||||
46
invokeai/app/invocations/flux_fill.py
Normal file
46
invokeai/app/invocations/flux_fill.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxFillConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_fill_output")
|
||||
class FluxFillOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Fill invocation."""
|
||||
|
||||
fill_cond: FluxFillConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_fill",
|
||||
title="FLUX Fill Conditioning",
|
||||
tags=["inpaint"],
|
||||
category="inpaint",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class FluxFillInvocation(BaseInvocation):
|
||||
"""Prepare the FLUX Fill conditioning data."""
|
||||
|
||||
image: ImageField = InputField(description="The FLUX Fill reference image.")
|
||||
mask: TensorField = InputField(
|
||||
description="The bool inpainting mask. Excluded regions should be set to "
|
||||
"False, included regions should be set to True.",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxFillOutput:
|
||||
return FluxFillOutput(fill_cond=FluxFillConditioningField(image=self.image, mask=self.mask))
|
||||
@@ -4,7 +4,7 @@ from typing import List, Literal, Union
|
||||
from pydantic import field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField, UIType
|
||||
from invokeai.app.invocations.ip_adapter import (
|
||||
CLIP_VISION_MODEL_MAP,
|
||||
@@ -28,7 +28,6 @@ from invokeai.backend.model_manager.config import (
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxIPAdapterInvocation(BaseInvocation):
|
||||
"""Collects FLUX IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@@ -3,14 +3,13 @@ from typing import Optional
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
@@ -28,11 +27,10 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="FLUX LoRA",
|
||||
title="Apply LoRA - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.2.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.2.1",
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
|
||||
@@ -107,11 +105,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"flux_lora_collection_loader",
|
||||
title="FLUX LoRA Collection Loader",
|
||||
title="Apply LoRA Collection - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.3.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.3.1",
|
||||
)
|
||||
class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to a FLUX transformer."""
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Literal
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -17,8 +16,8 @@ from invokeai.app.util.t5_model_identifier import (
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
@@ -37,11 +36,10 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Flux Main Model",
|
||||
title="Main Model - FLUX",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.5",
|
||||
classification=Classification.Prototype,
|
||||
version="1.0.6",
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
120
invokeai/app/invocations/flux_redux.py
Normal file
120
invokeai/app/invocations/flux_redux.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxReduxConditioningField,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.starter_models import siglip
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation_output("flux_redux_output")
|
||||
class FluxReduxOutput(BaseInvocationOutput):
|
||||
"""The conditioning output of a FLUX Redux invocation."""
|
||||
|
||||
redux_cond: FluxReduxConditioningField = OutputField(
|
||||
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_redux",
|
||||
title="FLUX Redux",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="2.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class FluxReduxInvocation(BaseInvocation):
|
||||
"""Runs a FLUX Redux model to generate a conditioning tensor."""
|
||||
|
||||
image: ImageField = InputField(description="The FLUX Redux image prompt.")
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None,
|
||||
description="The bool mask associated with this FLUX Redux image prompt. Excluded regions should be set to "
|
||||
"False, included regions should be set to True.",
|
||||
)
|
||||
redux_model: ModelIdentifierField = InputField(
|
||||
description="The FLUX Redux model to use.",
|
||||
title="FLUX Redux Model",
|
||||
ui_type=UIType.FluxReduxModel,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxReduxOutput:
|
||||
image = context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
encoded_x = self._siglip_encode(context, image)
|
||||
redux_conditioning = self._flux_redux_encode(context, encoded_x)
|
||||
|
||||
tensor_name = context.tensors.save(redux_conditioning)
|
||||
return FluxReduxOutput(
|
||||
redux_cond=FluxReduxConditioningField(conditioning=TensorField(tensor_name=tensor_name), mask=self.mask)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
|
||||
siglip_model_config = self._get_siglip_model(context)
|
||||
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
|
||||
assert isinstance(siglip_pipeline, SigLipPipeline)
|
||||
return siglip_pipeline.encode_image(
|
||||
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor) -> torch.Tensor:
|
||||
with context.models.load(self.redux_model).model_on_device() as (_, flux_redux):
|
||||
assert isinstance(flux_redux, FluxReduxModel)
|
||||
dtype = next(flux_redux.parameters()).dtype
|
||||
encoded_x = encoded_x.to(dtype=dtype)
|
||||
return flux_redux(encoded_x)
|
||||
|
||||
def _get_siglip_model(self, context: InvocationContext) -> AnyModelConfig:
|
||||
siglip_models = context.models.search_by_attrs(name=siglip.name, base=BaseModelType.Any, type=ModelType.SigLIP)
|
||||
|
||||
if not len(siglip_models) > 0:
|
||||
context.logger.warning(
|
||||
f"The SigLIP model required by FLUX Redux ({siglip.name}) is not installed. Downloading and installing now. This may take a while."
|
||||
)
|
||||
|
||||
# TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
|
||||
config_overrides = ModelRecordChanges(name=siglip.name, type=ModelType.SigLIP)
|
||||
|
||||
# Queue the job
|
||||
job = context._services.model_manager.install.heuristic_import(siglip.source, config=config_overrides)
|
||||
|
||||
# Wait for up to 10 minutes - model is ~3.5GB
|
||||
context._services.model_manager.install.wait_for_job(job, timeout=600)
|
||||
|
||||
siglip_models = context.models.search_by_attrs(
|
||||
name=siglip.name,
|
||||
base=BaseModelType.Any,
|
||||
type=ModelType.SigLIP,
|
||||
)
|
||||
|
||||
if len(siglip_models) == 0:
|
||||
context.logger.error("Error while fetching SigLIP for FLUX Redux")
|
||||
assert len(siglip_models) == 1
|
||||
|
||||
return siglip_models[0]
|
||||
@@ -4,7 +4,7 @@ from typing import Iterator, Literal, Optional, Tuple
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.model_manager import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -26,11 +26,10 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="FLUX Text Encoding",
|
||||
title="Prompt - FLUX",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
version="1.1.2",
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a flux image."""
|
||||
|
||||
@@ -22,10 +22,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"flux_vae_decode",
|
||||
title="FLUX Latents to Image",
|
||||
title="Latents to Image - FLUX",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.2",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -41,16 +41,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision).
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1090 # Determined experimentally.
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add a 20% buffer to the working memory estimate to be safe.
|
||||
working_memory = working_memory * 1.2
|
||||
return int(working_memory)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
|
||||
@@ -19,10 +19,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"flux_vae_encode",
|
||||
title="FLUX Image to Latents",
|
||||
title="Image to Latents - FLUX",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("ideal_size_output")
|
||||
@@ -19,9 +19,9 @@ class IdealSizeOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"ideal_size",
|
||||
title="Ideal Size",
|
||||
title="Ideal Size - SD1.5, SDXL",
|
||||
tags=["latents", "math", "ideal_size"],
|
||||
version="1.0.4",
|
||||
version="1.0.5",
|
||||
)
|
||||
class IdealSizeInvocation(BaseInvocation):
|
||||
"""Calculates the ideal size for generation to avoid duplication"""
|
||||
|
||||
@@ -355,7 +355,6 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["image", "unsharp_mask"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Applies an unsharp mask filter to an image"""
|
||||
@@ -1051,7 +1050,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Internal,
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Handles Canvas V2 image output masking and cropping"""
|
||||
@@ -1089,6 +1088,131 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
|
||||
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
|
||||
If the fade size is 0, the mask is returned as-is.
|
||||
"""
|
||||
|
||||
mask: ImageField = InputField(description="The mask to expand")
|
||||
threshold: int = InputField(default=0, ge=0, le=255, description="The threshold for the binary mask (0-255)")
|
||||
fade_size_px: int = InputField(default=32, ge=0, description="The size of the fade in pixels")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.fade_size_px == 0:
|
||||
# If the fade size is 0, just return the mask as-is.
|
||||
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
np_mask = numpy.array(pil_mask)
|
||||
|
||||
# Threshold the mask to create a binary mask - 0 for black, 255 for white
|
||||
# If we don't threshold we can get some weird artifacts
|
||||
np_mask = numpy.where(np_mask > self.threshold, 255, 0).astype(numpy.uint8)
|
||||
|
||||
# Create a mask for the black region (1 where black, 0 otherwise)
|
||||
black_mask = (np_mask == 0).astype(numpy.uint8)
|
||||
|
||||
# Invert the black region
|
||||
bg_mask = 1 - black_mask
|
||||
|
||||
# Create a distance transform of the inverted mask
|
||||
dist = cv2.distanceTransform(bg_mask, cv2.DIST_L2, 5)
|
||||
|
||||
# Normalize distances so that pixels <fade_size_px become a linear gradient (0 to 1)
|
||||
d_norm = numpy.clip(dist / self.fade_size_px, 0, 1)
|
||||
|
||||
# Control points: x values (normalized distance) and corresponding fade pct y values.
|
||||
|
||||
# There are some magic numbers here that are used to create a smooth transition:
|
||||
# - The first point is at 0% of fade size from edge of mask (meaning the edge of the mask), and is 0% fade (black)
|
||||
# - The second point is 1px from the edge of the mask and also has 0% fade, effectively expanding the mask
|
||||
# by 1px. This fixes an issue where artifacts can occur at the edge of the mask
|
||||
# - The third point is at 20% of the fade size from the edge of the mask and has 20% fade
|
||||
# - The fourth point is at 80% of the fade size from the edge of the mask and has 90% fade
|
||||
# - The last point is at 100% of the fade size from the edge of the mask and has 100% fade (white)
|
||||
|
||||
# x values: 0 = mask edge, 1 = fade_size_px from edge
|
||||
x_control = numpy.array([0.0, 1.0 / self.fade_size_px, 0.2, 0.8, 1.0])
|
||||
# y values: 0 = black, 1 = white
|
||||
y_control = numpy.array([0.0, 0.0, 0.2, 0.9, 1.0])
|
||||
|
||||
# Fit a cubic polynomial that smoothly passes through the control points
|
||||
coeffs = numpy.polyfit(x_control, y_control, 3)
|
||||
poly = numpy.poly1d(coeffs)
|
||||
|
||||
# Evaluate the polynomial
|
||||
feather = poly(d_norm)
|
||||
|
||||
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
|
||||
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
|
||||
# polynomial fit, which is a best approximation of the control points but not an exact match.
|
||||
|
||||
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
|
||||
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
|
||||
|
||||
# Force pixels at or beyond the fade distance to exactly 1.0
|
||||
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
|
||||
|
||||
# Clip any other values to ensure they're in the valid range [0,1]
|
||||
feather = numpy.clip(feather, 0, 1)
|
||||
|
||||
# Build final image.
|
||||
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
|
||||
|
||||
# Convert back to PIL, grayscale
|
||||
pil_result = Image.fromarray(np_result.astype(numpy.uint8), mode="L")
|
||||
|
||||
image_dto = context.images.save(image=pil_result, image_category=ImageCategory.MASK)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"apply_mask_to_image",
|
||||
title="Apply Mask to Image",
|
||||
tags=["image", "mask", "blend"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""
|
||||
Extracts a region from a generated image using a mask and blends it seamlessly onto a source image.
|
||||
The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="The image from which to extract the masked region")
|
||||
mask: ImageField = InputField(description="The mask defining the region (black=keep, white=discard)")
|
||||
invert_mask: bool = InputField(
|
||||
default=False,
|
||||
description="Whether to invert the mask before applying it",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load images
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.invert_mask:
|
||||
# Invert the mask if requested
|
||||
mask = ImageOps.invert(mask.copy())
|
||||
|
||||
# Combine the mask as the alpha channel of the image
|
||||
r, g, b, _ = image.split() # Split the image into RGB and alpha channels
|
||||
result_image = Image.merge("RGBA", (r, g, b, mask)) # Use the mask as the new alpha channel
|
||||
|
||||
# Save the resulting image
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_noise",
|
||||
title="Add Image Noise",
|
||||
@@ -1159,7 +1283,6 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
tags=["image", "crop"],
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
|
||||
@@ -1186,7 +1309,6 @@ class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
tags=["image", "crop"],
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Paste the source image into the target image at the given bounding box.
|
||||
|
||||
@@ -31,10 +31,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"i2l",
|
||||
title="Image to Latents",
|
||||
title="Image to Latents - SD1.5, SDXL",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
version="1.1.1",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
@@ -13,10 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.starter_models import (
|
||||
StarterModel,
|
||||
@@ -24,6 +22,7 @@ from invokeai.backend.model_manager.starter_models import (
|
||||
ip_adapter_sd_image_encoder,
|
||||
ip_adapter_sdxl_image_encoder,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
@@ -69,7 +68,13 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] =
|
||||
}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.5.0")
|
||||
@invocation(
|
||||
"ip_adapter",
|
||||
title="IP-Adapter - SD1.5, SDXL",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="1.5.1",
|
||||
)
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"l2i",
|
||||
title="Latents to Image",
|
||||
title="Latents to Image - SD1.5, SDXL",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.3.1",
|
||||
version="1.3.2",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -60,7 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 960 # Determined experimentally.
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
@@ -84,9 +84,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
# We add 20% to the working memory estimate to be safe.
|
||||
working_memory = int(working_memory * 1.2)
|
||||
return working_memory
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
67
invokeai/app/invocations/llava_onevision_vllm.py
Normal file
67
invokeai/app/invocations/llava_onevision_vllm.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"llava_onevision_vllm",
|
||||
title="LLaVA OneVision VLLM",
|
||||
tags=["vllm"],
|
||||
category="vllm",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
"""Run a LLaVA OneVision VLLM model."""
|
||||
|
||||
images: list[ImageField] | ImageField | None = InputField(default=None, max_length=3, description="Input image.")
|
||||
prompt: str = InputField(
|
||||
default="",
|
||||
description="Input text prompt.",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
vllm_model: ModelIdentifierField = InputField(
|
||||
title="LLaVA Model Type",
|
||||
description=FieldDescriptions.vllm_model,
|
||||
ui_type=UIType.LlavaOnevisionModel,
|
||||
)
|
||||
|
||||
@field_validator("images", mode="before")
|
||||
def listify_images(cls, v: Any) -> list:
|
||||
if v is None:
|
||||
return v
|
||||
if not isinstance(v, list):
|
||||
return [v]
|
||||
return v
|
||||
|
||||
def _get_images(self, context: InvocationContext) -> list[Image]:
|
||||
if self.images is None:
|
||||
return []
|
||||
|
||||
image_fields = self.images if isinstance(self.images, list) else [self.images]
|
||||
return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
images = self._get_images(context)
|
||||
|
||||
with context.models.load(self.vllm_model) as vllm_model:
|
||||
assert isinstance(vllm_model, LlavaOnevisionModel)
|
||||
output = vllm_model.run(
|
||||
prompt=self.prompt,
|
||||
images=images,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=TorchDevice.choose_torch_dtype(),
|
||||
)
|
||||
|
||||
return StringOutput(value=output)
|
||||
83
invokeai/app/invocations/load_custom_nodes.py
Normal file
83
invokeai/app/invocations/load_custom_nodes.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_custom_nodes(custom_nodes_path: Path, logger: logging.Logger):
|
||||
"""
|
||||
Loads all custom nodes from the custom_nodes_path directory.
|
||||
|
||||
If custom_nodes_path does not exist, it creates it.
|
||||
|
||||
It also copies the custom_nodes/README.md file to the custom_nodes_path directory. Because this file may change,
|
||||
it is _always_ copied to the custom_nodes_path directory.
|
||||
|
||||
Then, it crawls the custom_nodes_path directory and imports all top-level directories as python modules.
|
||||
|
||||
If the directory does not contain an __init__.py file or starts with an `_` or `.`, it is skipped.
|
||||
"""
|
||||
|
||||
# create the custom nodes directory if it does not exist
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy the README file to the custom nodes directory
|
||||
source_custom_nodes_readme_path = Path(__file__).parent / "custom_nodes/README.md"
|
||||
target_custom_nodes_readme_path = Path(custom_nodes_path) / "README.md"
|
||||
|
||||
# copy our custom nodes README to the custom nodes directory
|
||||
shutil.copy(source_custom_nodes_readme_path, target_custom_nodes_readme_path)
|
||||
|
||||
loaded_packs: list[str] = []
|
||||
failed_packs: list[str] = []
|
||||
|
||||
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
|
||||
for d in custom_nodes_path.iterdir():
|
||||
# skip files
|
||||
if not d.is_dir():
|
||||
continue
|
||||
|
||||
# skip hidden directories
|
||||
if d.name.startswith("_") or d.name.startswith("."):
|
||||
continue
|
||||
|
||||
# skip directories without an `__init__.py`
|
||||
init = d / "__init__.py"
|
||||
if not init.exists():
|
||||
continue
|
||||
|
||||
module_name = init.parent.stem
|
||||
|
||||
# skip if already imported
|
||||
if module_name in globals():
|
||||
continue
|
||||
|
||||
# load the module
|
||||
spec = spec_from_file_location(module_name, init.absolute())
|
||||
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warning(f"Could not load {init}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loading node pack {module_name}")
|
||||
|
||||
try:
|
||||
module = module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
loaded_packs.append(module_name)
|
||||
except Exception:
|
||||
failed_packs.append(module_name)
|
||||
full_error = traceback.format_exc()
|
||||
logger.error(f"Failed to load node pack {module_name} (may have partially loaded):\n{full_error}")
|
||||
|
||||
del init, module_name
|
||||
|
||||
loaded_count = len(loaded_packs)
|
||||
if loaded_count > 0:
|
||||
logger.info(
|
||||
f"Loaded {loaded_count} node pack{'s' if loaded_count != 1 else ''} from {custom_nodes_path}: {', '.join(loaded_packs)}"
|
||||
)
|
||||
@@ -4,7 +4,6 @@ from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
@@ -58,7 +57,6 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
|
||||
@@ -67,7 +65,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
|
||||
if self.invert:
|
||||
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
|
||||
@@ -87,7 +85,6 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class InvertTensorMaskInvocation(BaseInvocation):
|
||||
"""Inverts a tensor mask."""
|
||||
@@ -234,7 +231,6 @@ WHITE = ColorField(r=255, g=255, b=255, a=255)
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class GetMaskBoundingBoxInvocation(BaseInvocation):
|
||||
"""Gets the bounding box of the given mask image."""
|
||||
|
||||
@@ -284,6 +284,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
tags=["metadata"],
|
||||
category="metadata",
|
||||
version="1.0.0",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
class MetadataFieldExtractorInvocation(BaseInvocation):
|
||||
"""Extracts the text value from an image's metadata given a key.
|
||||
|
||||
1164
invokeai/app/invocations/metadata_linked.py
Normal file
1164
invokeai/app/invocations/metadata_linked.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -15,10 +14,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@@ -122,11 +119,10 @@ class ModelIdentifierOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"model_identifier",
|
||||
title="Model identifier",
|
||||
title="Any Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.0.1",
|
||||
)
|
||||
class ModelIdentifierInvocation(BaseInvocation):
|
||||
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
|
||||
@@ -144,10 +140,10 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
title="Main Model - SD1.5",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.4",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
@@ -181,7 +177,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
@invocation("lora_loader", title="Apply LoRA - SD1.5", tags=["model"], category="model", version="1.0.4")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@@ -244,7 +240,7 @@ class LoRASelectorOutput(BaseInvocationOutput):
|
||||
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
|
||||
|
||||
|
||||
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
|
||||
@invocation("lora_selector", title="Select LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
@@ -257,7 +253,9 @@ class LoRASelectorInvocation(BaseInvocation):
|
||||
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
|
||||
|
||||
|
||||
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.1.0")
|
||||
@invocation(
|
||||
"lora_collection_loader", title="Apply LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.2"
|
||||
)
|
||||
class LoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
@@ -320,10 +318,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_loader",
|
||||
title="SDXL LoRA",
|
||||
title="Apply LoRA - SDXL",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.5",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@@ -400,10 +398,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_collection_loader",
|
||||
title="SDXL LoRA Collection Loader",
|
||||
title="Apply LoRA Collection - SDXL",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.1.0",
|
||||
version="1.1.2",
|
||||
)
|
||||
class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
|
||||
@@ -469,7 +467,9 @@ class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
|
||||
@invocation(
|
||||
"vae_loader", title="VAE Model - SD1.5, SDXL, SD3, FLUX", tags=["vae", "model"], category="model", version="1.0.4"
|
||||
)
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
@@ -496,10 +496,10 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"seamless",
|
||||
title="Seamless",
|
||||
title="Apply Seamless - SD1.5, SDXL",
|
||||
tags=["seamless", "model"],
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
version="1.0.2",
|
||||
)
|
||||
class SeamlessModeInvocation(BaseInvocation):
|
||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||
@@ -539,7 +539,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||
|
||||
|
||||
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.1")
|
||||
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2")
|
||||
class FreeUInvocation(BaseInvocation):
|
||||
"""
|
||||
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
|
||||
|
||||
@@ -72,10 +72,10 @@ class NoiseOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"noise",
|
||||
title="Noise",
|
||||
title="Create Latent Noise",
|
||||
tags=["latents", "noise"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.3",
|
||||
)
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
@@ -265,13 +265,9 @@ class ImageInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_dto = context.images.get_dto(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto=image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
@@ -6,7 +6,7 @@ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
@@ -23,7 +23,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
|
||||
@@ -32,11 +32,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"sd3_denoise",
|
||||
title="SD3 Denoise",
|
||||
title="Denoise - SD3",
|
||||
tags=["image", "sd3"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.1.1",
|
||||
)
|
||||
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a SD3 model."""
|
||||
|
||||
@@ -2,7 +2,7 @@ import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@@ -21,11 +21,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"sd3_i2l",
|
||||
title="SD3 Image to Latents",
|
||||
title="Image to Latents - SD3",
|
||||
tags=["image", "latents", "vae", "i2l", "sd3"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.0.1",
|
||||
)
|
||||
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates latents from an image."""
|
||||
|
||||
@@ -24,10 +24,10 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@invocation(
|
||||
"sd3_l2i",
|
||||
title="SD3 Latents to Image",
|
||||
title="Latents to Image - SD3",
|
||||
tags=["latents", "image", "vae", "l2i", "sd3"],
|
||||
category="latents",
|
||||
version="1.3.1",
|
||||
version="1.3.2",
|
||||
)
|
||||
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@@ -43,16 +43,11 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision).
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1230 # Determined experimentally.
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add a 20% buffer to the working memory estimate to be safe.
|
||||
working_memory = working_memory * 1.2
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -14,7 +13,7 @@ from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
@@ -30,11 +29,10 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"sd3_model_loader",
|
||||
title="SD3 Main Model",
|
||||
title="Main Model - SD3",
|
||||
tags=["model", "sd3"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.0.1",
|
||||
)
|
||||
class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a SD3 base model, outputting its submodels."""
|
||||
|
||||
@@ -11,12 +11,12 @@ from transformers import (
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import SD3ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -29,11 +29,10 @@ SD3_T5_MAX_SEQ_LEN = 256
|
||||
|
||||
@invocation(
|
||||
"sd3_text_encoder",
|
||||
title="SD3 Text Encoding",
|
||||
title="Prompt - SD3",
|
||||
tags=["prompt", "conditioning", "sd3"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.0.1",
|
||||
)
|
||||
class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a SD3 image."""
|
||||
|
||||
@@ -2,7 +2,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@@ -24,7 +24,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
|
||||
@invocation("sdxl_model_loader", title="Main Model - SDXL", tags=["model", "sdxl"], category="model", version="1.0.4")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@@ -58,10 +58,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"sdxl_refiner_model_loader",
|
||||
title="SDXL Refiner Model",
|
||||
title="Refiner Model - SDXL",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.3",
|
||||
version="1.0.4",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
@@ -185,9 +185,9 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
# Find the largest mask.
|
||||
return [max(masks, key=lambda x: float(x.sum()))]
|
||||
elif self.mask_filter == "highest_box_score":
|
||||
assert (
|
||||
bounding_boxes is not None
|
||||
), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
|
||||
assert bounding_boxes is not None, (
|
||||
"Bounding boxes must be provided to use the 'highest_box_score' mask filter."
|
||||
)
|
||||
assert len(masks) == len(bounding_boxes)
|
||||
# Find the index of the bounding box with the highest score.
|
||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
||||
|
||||
@@ -45,7 +45,11 @@ class T2IAdapterOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
|
||||
"t2i_adapter",
|
||||
title="T2I-Adapter - SD1.5, SDXL",
|
||||
tags=["t2i_adapter", "control"],
|
||||
category="t2i_adapter",
|
||||
version="1.0.4",
|
||||
)
|
||||
class T2IAdapterInvocation(BaseInvocation):
|
||||
"""Collects T2I-Adapter info to pass to other nodes."""
|
||||
|
||||
@@ -7,7 +7,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
@@ -53,11 +53,10 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
|
||||
|
||||
@invocation(
|
||||
"tiled_multi_diffusion_denoise_latents",
|
||||
title="Tiled Multi-Diffusion Denoise Latents",
|
||||
title="Tiled Multi-Diffusion Denoise - SD1.5, SDXL",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
classification=Classification.Beta,
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
"""Tiled Multi-Diffusion denoising.
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -40,7 +39,6 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
@@ -74,7 +72,6 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
@@ -117,7 +114,6 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
@@ -168,7 +164,6 @@ class TileToPropertiesOutput(BaseInvocationOutput):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class TileToPropertiesInvocation(BaseInvocation):
|
||||
"""Split a Tile into its individual properties."""
|
||||
@@ -201,7 +196,6 @@ class PairTileImageOutput(BaseInvocationOutput):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class PairTileImageInvocation(BaseInvocation):
|
||||
"""Pair an image with its tile properties."""
|
||||
@@ -230,7 +224,6 @@ BLEND_MODES = Literal["Linear", "Seam"]
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Merge multiple tile images into a single image."""
|
||||
|
||||
@@ -9,6 +9,6 @@ def validate_weights(weights: Union[float, list[float]]) -> None:
|
||||
|
||||
|
||||
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
|
||||
"""Validate that begin_step_percent is less than end_step_percent"""
|
||||
if begin_step_percent >= end_step_percent:
|
||||
"""Validate that begin_step_percent is less than or equal to end_step_percent"""
|
||||
if begin_step_percent > end_step_percent:
|
||||
raise ValueError("Begin step percent must be less than or equal to end step percent")
|
||||
|
||||
@@ -1,12 +1,86 @@
|
||||
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
|
||||
import uvicorn
|
||||
|
||||
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
|
||||
def get_app():
|
||||
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
|
||||
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
|
||||
"""
|
||||
from invokeai.app.api_app import app, loop
|
||||
|
||||
return app, loop
|
||||
|
||||
|
||||
def run_app() -> None:
|
||||
# Before doing _anything_, parse CLI args!
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
"""The main entrypoint for the app."""
|
||||
# Parse the CLI arguments.
|
||||
InvokeAIArgs.parse_args()
|
||||
|
||||
from invokeai.app.api_app import invoke_api
|
||||
# Load config.
|
||||
app_config = get_config()
|
||||
|
||||
invoke_api()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
|
||||
# Configure the torch CUDA memory allocator.
|
||||
# NOTE: It is important that this happens before torch is imported.
|
||||
if app_config.pytorch_cuda_alloc_conf:
|
||||
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
|
||||
|
||||
# Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called.
|
||||
from invokeai.app.util.startup_utils import (
|
||||
apply_monkeypatches,
|
||||
check_cudnn,
|
||||
enable_dev_reload,
|
||||
find_open_port,
|
||||
register_mime_types,
|
||||
)
|
||||
|
||||
# Find an open port, and modify the config accordingly.
|
||||
first_open_port = find_open_port(app_config.port)
|
||||
if app_config.port != first_open_port:
|
||||
orig_config_port = app_config.port
|
||||
app_config.port = first_open_port
|
||||
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
|
||||
|
||||
# Miscellaneous startup tasks.
|
||||
apply_monkeypatches()
|
||||
register_mime_types()
|
||||
check_cudnn(logger)
|
||||
|
||||
# Initialize the app and event loop.
|
||||
app, loop = get_app()
|
||||
|
||||
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
|
||||
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
|
||||
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
|
||||
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
|
||||
|
||||
if app_config.dev_reload:
|
||||
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
|
||||
# imported.
|
||||
enable_dev_reload(custom_nodes_path=app_config.custom_nodes_path)
|
||||
|
||||
# Start the server.
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
port=app_config.port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level_network,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
|
||||
uvicorn_logger.handlers.clear()
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
|
||||
@@ -27,7 +27,7 @@ class BoardImageRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str | Literal["none"],
|
||||
board_id: str,
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Literal, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
@@ -13,15 +12,9 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
@@ -29,8 +22,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
VALUES (?, ?)
|
||||
@@ -42,16 +35,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE image_name = ?;
|
||||
@@ -62,8 +53,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
@@ -72,138 +61,108 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
# TODO: this isn't paginated yet?
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str | Literal["none"],
|
||||
board_id: str,
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
params: list[str | bool] = []
|
||||
|
||||
params: list[str | bool] = []
|
||||
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
|
||||
# Add the board_id filter
|
||||
if board_id == "none":
|
||||
stmt += "AND board_images.board_id IS NULL"
|
||||
else:
|
||||
stmt += "AND board_images.board_id = ?"
|
||||
params.append(board_id)
|
||||
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
params.append(is_intermediate)
|
||||
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
|
||||
# Execute the query
|
||||
self._cursor.execute(stmt, params)
|
||||
# Execute the query
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(stmt, params)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
|
||||
@@ -27,7 +27,7 @@ class BoardImagesServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str | Literal["none"],
|
||||
board_id: str,
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.board_images.board_images_base import BoardImagesServiceABC
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
@@ -26,7 +26,7 @@ class BoardImagesService(BoardImagesServiceABC):
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
self,
|
||||
board_id: str | Literal["none"],
|
||||
board_id: str,
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Union, cast
|
||||
|
||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||
@@ -19,20 +18,14 @@ from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
@@ -40,14 +33,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -55,8 +43,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
@@ -67,8 +55,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
@@ -76,8 +62,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
@@ -86,12 +72,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
@@ -102,11 +85,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
@@ -117,7 +99,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
@@ -128,7 +110,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
@@ -141,8 +123,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
@@ -153,11 +133,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build base query
|
||||
base_query = """
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
@@ -165,81 +144,67 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
self._cursor.execute(final_query, (limit, offset))
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
self._cursor.execute(count_query)
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
) -> list[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
cursor = self._conn.cursor()
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
self._cursor.execute(final_query)
|
||||
cursor.execute(final_query)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return boards
|
||||
|
||||
@@ -63,7 +63,11 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
|
||||
|
||||
def _board_handler(self, board_id: str) -> list[ImageDTO]:
|
||||
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(
|
||||
board_id,
|
||||
categories=None,
|
||||
is_intermediate=None,
|
||||
)
|
||||
return self._image_handler(image_names)
|
||||
|
||||
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||
|
||||
@@ -72,6 +72,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
outputs_dir: Path to directory for outputs.
|
||||
custom_nodes_dir: Path to directory for custom nodes.
|
||||
style_presets_dir: Path to directory for style presets.
|
||||
workflow_thumbnails_dir: Path to directory for workflow thumbnails.
|
||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
@@ -91,6 +92,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
|
||||
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
@@ -141,6 +143,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
|
||||
workflow_thumbnails_dir: Path = Field(default=Path("workflow_thumbnails"), description="Path to directory for workflow thumbnails.")
|
||||
|
||||
# LOGGING
|
||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||
@@ -169,6 +172,9 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
|
||||
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
|
||||
|
||||
# PyTorch Memory Allocator
|
||||
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
|
||||
@@ -300,6 +306,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the style presets directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.style_presets_dir)
|
||||
|
||||
@property
|
||||
def workflow_thumbnails_path(self) -> Path:
|
||||
"""Path to the workflow thumbnails directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.workflow_thumbnails_dir)
|
||||
|
||||
@property
|
||||
def convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
@@ -472,9 +483,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
assert config.schema_version == CONFIG_SCHEMA_VERSION, (
|
||||
f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
)
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
@@ -44,7 +44,8 @@ if TYPE_CHECKING:
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
|
||||
@@ -16,7 +16,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
)
|
||||
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@@ -22,21 +21,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
@@ -44,12 +36,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -58,9 +47,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
@@ -68,7 +56,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -77,10 +65,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -88,10 +73,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
@@ -102,7 +87,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
@@ -113,7 +98,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
@@ -124,7 +109,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
@@ -137,8 +122,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
@@ -152,110 +135,104 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND images.metadata LIKE ?
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND images.metadata LIKE ?
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
@@ -266,58 +243,48 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
try:
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
self._lock.acquire()
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
self._cursor.execute(query, image_names)
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
return count
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
@@ -328,8 +295,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -346,8 +311,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
@@ -380,7 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
@@ -389,34 +354,30 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -265,7 +265,11 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def delete_images_on_board(self, board_id: str):
|
||||
try:
|
||||
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(
|
||||
board_id,
|
||||
categories=None,
|
||||
is_intermediate=None,
|
||||
)
|
||||
for image_name in image_names:
|
||||
self.__invoker.services.image_files.delete(image_name)
|
||||
self.__invoker.services.image_records.delete_many(image_names)
|
||||
@@ -278,7 +282,7 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Failed to delete image files")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem deleting image records and files")
|
||||
self.__invoker.services.logger.error(f"Problem deleting image records and files: {str(e)}")
|
||||
raise e
|
||||
|
||||
def delete_intermediates(self) -> int:
|
||||
|
||||
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.urls.urls_base import UrlServiceBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_base import WorkflowThumbnailServiceBase
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
|
||||
@@ -65,6 +66,7 @@ class InvocationServices:
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
style_preset_records: "StylePresetRecordsStorageBase",
|
||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
||||
workflow_thumbnails: "WorkflowThumbnailServiceBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@@ -91,3 +93,4 @@ class InvocationServices:
|
||||
self.conditioning = conditioning
|
||||
self.style_preset_records = style_preset_records
|
||||
self.style_preset_image_files = style_preset_image_files
|
||||
self.workflow_thumbnails = workflow_thumbnails
|
||||
|
||||
@@ -10,9 +10,9 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
|
||||
@@ -38,9 +38,9 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelConfigBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
@@ -49,8 +49,8 @@ from invokeai.backend.model_manager.metadata import (
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -182,9 +182,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or ModelRecordChanges()
|
||||
info: AnyModelConfig = ModelProbe.probe(
|
||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
||||
) # type: ignore
|
||||
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
|
||||
|
||||
if preferred_name := config.name:
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@@ -644,12 +642,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
|
||||
config = config or ModelRecordChanges()
|
||||
hash_algo = self._app_config.hashing_algorithm
|
||||
fields = config.model_dump()
|
||||
|
||||
try:
|
||||
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
|
||||
except InvalidModelConfigException:
|
||||
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or ModelRecordChanges()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
||||
info = info or self._probe(model_path, config)
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
|
||||
@@ -5,9 +5,10 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
|
||||
@@ -11,7 +11,7 @@ from torch import load as torch_load
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
LoadedModelWithoutConfig,
|
||||
@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.load import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@@ -85,8 +86,11 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
if scan_result.scan_err:
|
||||
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
|
||||
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"LoadedModel",
|
||||
]
|
||||
|
||||
@@ -14,10 +14,12 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
|
||||
@@ -60,11 +60,9 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
@@ -78,7 +76,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._logger = logger
|
||||
|
||||
@property
|
||||
@@ -96,38 +93,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
|
||||
@@ -139,21 +136,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
@@ -164,23 +161,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@@ -192,33 +189,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
@@ -227,16 +224,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
count = 0
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
@@ -284,17 +280,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# Parse the model configs.
|
||||
results: list[AnyModelConfig] = []
|
||||
@@ -313,34 +310,28 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def list_models(
|
||||
@@ -356,33 +347,32 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
cursor = self._db.conn.cursor()
|
||||
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
|
||||
|
||||
@@ -21,10 +21,16 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
:param output_dir: The folder where the serialized objects will be stored
|
||||
:param safe_globals: A list of types to be added to the safe globals for torch serialization
|
||||
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
output_dir: Path,
|
||||
safe_globals: list[type],
|
||||
ephemeral: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
@@ -42,6 +48,8 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||
self.__obj_class_name: Optional[str] = None
|
||||
|
||||
torch.serialization.add_safe_globals(safe_globals) if safe_globals else None
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
@@ -33,7 +33,7 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
|
||||
"""Enqueues all permutations of a batch for execution."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
from itertools import chain, product
|
||||
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||
from typing import Generator, Literal, Optional, TypeAlias, Union, cast
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
@@ -406,61 +406,143 @@ class IsFullResult(BaseModel):
|
||||
# region Util
|
||||
|
||||
|
||||
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
|
||||
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
|
||||
"""
|
||||
Populates the given graph with the given batch data items.
|
||||
"""
|
||||
graph_clone = graph.model_copy(deep=True)
|
||||
for item in node_field_values:
|
||||
node = graph_clone.get_node(item.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
setattr(node, item.field_name, item.value)
|
||||
graph_clone.update_node(item.node_path, node)
|
||||
return graph_clone
|
||||
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
|
||||
field_values_json for each session.
|
||||
|
||||
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
|
||||
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
|
||||
the field.
|
||||
|
||||
def create_session_nfv_tuples(
|
||||
batch: Batch, maximum: int
|
||||
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
|
||||
"""
|
||||
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||
that was applied to the graph.
|
||||
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
|
||||
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
|
||||
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
|
||||
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
|
||||
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
|
||||
Each inner list now represents the substitution values for a single permutation (session).
|
||||
- For each permutation, substitute the values into the graph
|
||||
|
||||
This function is optimized for performance, as it is used to generate a large number of sessions at once.
|
||||
|
||||
Args:
|
||||
batch: The batch to generate sessions from
|
||||
maximum: The maximum number of sessions to generate
|
||||
|
||||
Returns:
|
||||
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
|
||||
generator will stop early if the maximum number of sessions is reached.
|
||||
"""
|
||||
|
||||
# TODO: Should this be a class method on Batch?
|
||||
|
||||
data: list[list[tuple[NodeFieldValue]]] = []
|
||||
data: list[list[tuple[dict]]] = []
|
||||
batch_data_collection = batch.data if batch.data is not None else []
|
||||
for batch_datum_list in batch_data_collection:
|
||||
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
||||
|
||||
node_field_values_to_zip: list[list[NodeFieldValue]] = []
|
||||
for batch_datum_list in batch_data_collection:
|
||||
node_field_values_to_zip: list[list[dict]] = []
|
||||
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
|
||||
for batch_datum in batch_datum_list:
|
||||
node_field_values = [
|
||||
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
|
||||
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
|
||||
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
|
||||
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
# Zip the dicts together to create a list of dicts for each permutation
|
||||
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
# We serialize the graph and session once, then mutate the graph dict in place for each session.
|
||||
#
|
||||
# This sounds scary, but it's actually fine.
|
||||
#
|
||||
# The batch prep logic injects field values into the same fields for each generated session.
|
||||
#
|
||||
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
|
||||
# [
|
||||
# (
|
||||
# {"node_path": "1", "field_name": "a", "value": 1},
|
||||
# {"node_path": "2", "field_name": "b", "value": 2},
|
||||
# {"node_path": "3", "field_name": "c", "value": 3},
|
||||
# ),
|
||||
# (
|
||||
# {"node_path": "1", "field_name": "a", "value": 4},
|
||||
# {"node_path": "2", "field_name": "b", "value": 5},
|
||||
# {"node_path": "3", "field_name": "c", "value": 6},
|
||||
# )
|
||||
# ]
|
||||
#
|
||||
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
|
||||
# No matter the complexity of the batch, this property holds true.
|
||||
#
|
||||
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
|
||||
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
|
||||
#
|
||||
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
|
||||
# batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
|
||||
# but this was also slow.
|
||||
#
|
||||
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
|
||||
# objects for each session.
|
||||
#
|
||||
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
|
||||
# dict as the session's graph.
|
||||
|
||||
# Dump the batch's graph to a dict once
|
||||
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
|
||||
# overwritten for each session by the mutated graph_as_dict.
|
||||
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
|
||||
|
||||
# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
|
||||
count = 0
|
||||
|
||||
# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
|
||||
# still limited by the maximum number of sessions.
|
||||
for _ in range(batch.runs):
|
||||
for d in product(*data):
|
||||
if count >= maximum:
|
||||
# We've reached the maximum number of sessions we may generate
|
||||
return
|
||||
|
||||
# Flatten the list of lists of dicts into a single list of dicts
|
||||
# TODO(psyche): Is the a more efficient way to do this?
|
||||
flat_node_field_values = list(chain.from_iterable(d))
|
||||
graph = populate_graph(batch.graph, flat_node_field_values)
|
||||
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
|
||||
|
||||
# Need a fresh ID for each session
|
||||
session_id = uuid_string()
|
||||
|
||||
# Mutate the session dict in place
|
||||
session_dict["id"] = session_id
|
||||
|
||||
# Substitute the values into the graph
|
||||
for nfv in flat_node_field_values:
|
||||
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
|
||||
|
||||
# Mutate the session dict in place
|
||||
session_dict["graph"] = graph_as_dict
|
||||
|
||||
# Serialize the session and field values
|
||||
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
|
||||
session_json = json.dumps(session_dict, default=to_jsonable_python)
|
||||
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
|
||||
|
||||
# Yield the session_id, session_json, and field_values_json
|
||||
yield (session_id, session_json, field_values_json)
|
||||
|
||||
# Increment the count so we know when to stop
|
||||
count += 1
|
||||
|
||||
|
||||
def calc_session_count(batch: Batch) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring
|
||||
the overhead of actually generating them. Adapted from `create_sessions().
|
||||
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
|
||||
creating them, as is done in `create_session_nfv_tuples()`.
|
||||
|
||||
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
|
||||
many were _actually_ created (which may be less due to the maximum number of sessions).
|
||||
"""
|
||||
# TODO: Should this be a class method on Batch?
|
||||
if not batch.data:
|
||||
@@ -476,42 +558,78 @@ def calc_session_count(batch: Batch) -> int:
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
||||
class SessionQueueValueToInsert(NamedTuple):
|
||||
"""A tuple of values to insert into the session_queue table"""
|
||||
ValueToInsertTuple: TypeAlias = tuple[
|
||||
str, # queue_id
|
||||
str, # session (as stringified JSON)
|
||||
str, # session_id
|
||||
str, # batch_id
|
||||
str | None, # field_values (optional, as stringified JSON)
|
||||
int, # priority
|
||||
str | None, # workflow (optional, as stringified JSON)
|
||||
str | None, # origin (optional)
|
||||
str | None, # destination (optional)
|
||||
int | None, # retried_from_item_id (optional, this is always None for new items)
|
||||
]
|
||||
"""A type alias for the tuple of values to insert into the session queue table.
|
||||
|
||||
# Careful with the ordering of this - it must match the insert statement
|
||||
queue_id: str # queue_id
|
||||
session: str # session json
|
||||
session_id: str # session_id
|
||||
batch_id: str # batch_id
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
destination: str | None
|
||||
retried_from_item_id: int | None = None
|
||||
**If you change this, be sure to update the `enqueue_batch` and `retry_items_by_id` methods in the session queue service!**
|
||||
"""
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
def prepare_values_to_insert(
|
||||
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
|
||||
) -> list[ValueToInsertTuple]:
|
||||
"""
|
||||
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
|
||||
`executemany` statement to insert multiple rows at once.
|
||||
|
||||
Args:
|
||||
queue_id: The ID of the queue to insert the items into
|
||||
batch: The batch to prepare the values for
|
||||
priority: The priority of the queue items
|
||||
max_new_queue_items: The maximum number of queue items to insert
|
||||
|
||||
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
|
||||
values_to_insert: ValuesToInsert = []
|
||||
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
# sessions must have unique id
|
||||
session.id = uuid_string()
|
||||
Returns:
|
||||
A list of tuples to insert into the session queue table. Each tuple contains the following values:
|
||||
- queue_id
|
||||
- session (as stringified JSON)
|
||||
- session_id
|
||||
- batch_id
|
||||
- field_values (optional, as stringified JSON)
|
||||
- priority
|
||||
- workflow (optional, as stringified JSON)
|
||||
- origin (optional)
|
||||
- destination (optional)
|
||||
- retried_from_item_id (optional, this is always None for new items)
|
||||
"""
|
||||
|
||||
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
|
||||
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items), the
|
||||
# this difference becomes noticeable.
|
||||
#
|
||||
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.
|
||||
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
|
||||
# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
|
||||
# not support by default. Apparently there are sets somewhere in the graph.
|
||||
|
||||
# The same workflow is used for all sessions in the batch - serialize it once
|
||||
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
|
||||
|
||||
for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
values_to_insert.append(
|
||||
SessionQueueValueToInsert(
|
||||
queue_id=queue_id,
|
||||
session=session.model_dump_json(warnings=False, exclude_none=True), # as json
|
||||
session_id=session.id,
|
||||
batch_id=batch.batch_id,
|
||||
# must use pydantic_encoder bc field_values is a list of models
|
||||
field_values=json.dumps(field_values, default=to_jsonable_python) if field_values else None, # as json
|
||||
priority=priority,
|
||||
workflow=json.dumps(workflow, default=to_jsonable_python) if workflow else None, # as json
|
||||
origin=batch.origin,
|
||||
destination=batch.destination,
|
||||
(
|
||||
queue_id,
|
||||
session_json,
|
||||
session_id,
|
||||
batch.batch_id,
|
||||
field_values_json,
|
||||
priority,
|
||||
workflow_json,
|
||||
batch.origin,
|
||||
batch.destination,
|
||||
None,
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from pydantic_core import to_jsonable_python
|
||||
@@ -27,7 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
SessionQueueValueToInsert,
|
||||
ValueToInsertTuple,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
@@ -38,9 +38,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
class SqliteSessionQueue(SessionQueueBase):
|
||||
__invoker: Invoker
|
||||
__conn: sqlite3.Connection
|
||||
__cursor: sqlite3.Cursor
|
||||
__lock: threading.RLock
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
@@ -56,9 +53,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self.__lock = db.lock
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
self._conn = db.conn
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
@@ -66,8 +61,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -75,14 +70,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
@@ -92,11 +86,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, self.__cursor.fetchone()[0])
|
||||
return cast(int, cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
@@ -106,12 +101,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
|
||||
return cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
|
||||
|
||||
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
@@ -133,19 +130,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if requested_count > enqueued_count:
|
||||
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||
|
||||
self.__cursor.executemany(
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
@@ -157,25 +152,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -183,52 +172,40 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -242,8 +219,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
|
||||
@@ -251,12 +228,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(status, error_type, error_message, error_traceback, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
@@ -264,48 +239,36 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -313,8 +276,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
@@ -322,17 +285,16 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -342,8 +304,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
OR status = 'canceled'
|
||||
)
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -351,8 +312,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
@@ -360,12 +321,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
@@ -394,8 +353,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
@@ -406,7 +365,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -414,8 +373,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -423,20 +382,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -446,7 +403,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -454,8 +411,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -463,20 +420,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id is ?
|
||||
@@ -485,7 +440,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id]
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -493,8 +448,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -502,7 +457,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
@@ -510,21 +465,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND status == 'pending'
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -532,8 +485,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -541,43 +494,35 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
session_json = session.model_dump_json(warnings=False, exclude_none=True)
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET session = ?
|
||||
@@ -585,12 +530,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
@@ -601,83 +544,71 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
try:
|
||||
item_id = cursor
|
||||
self.__lock.acquire()
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
self.__cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
params.append(status)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] for row in counts_result)
|
||||
@@ -696,29 +627,23 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
@@ -734,24 +659,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
@@ -770,9 +689,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
values_to_insert: list[SessionQueueValueToInsert] = []
|
||||
cursor = self._conn.cursor()
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
for item_id in item_ids:
|
||||
@@ -798,23 +716,23 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
else queue_item.item_id
|
||||
)
|
||||
|
||||
value_to_insert = SessionQueueValueToInsert(
|
||||
queue_id=queue_item.queue_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
destination=queue_item.destination,
|
||||
field_values=field_values_json,
|
||||
origin=queue_item.origin,
|
||||
priority=queue_item.priority,
|
||||
workflow=workflow_json,
|
||||
session=cloned_session_json,
|
||||
session_id=cloned_session.id,
|
||||
retried_from_item_id=retried_from_item_id,
|
||||
value_to_insert: ValueToInsertTuple = (
|
||||
queue_item.queue_id,
|
||||
cloned_session_json,
|
||||
cloned_session.id,
|
||||
queue_item.batch_id,
|
||||
field_values_json,
|
||||
queue_item.priority,
|
||||
workflow_json,
|
||||
queue_item.origin,
|
||||
queue_item.destination,
|
||||
retried_from_item_id,
|
||||
)
|
||||
values_to_insert.append(value_to_insert)
|
||||
|
||||
# TODO(psyche): Handle max queue size?
|
||||
|
||||
self.__cursor.executemany(
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
@@ -822,12 +740,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
retry_result = RetryItemsResult(
|
||||
queue_id=queue_id,
|
||||
retried_item_ids=retried_item_ids,
|
||||
|
||||
@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationRegistry,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
return InvocationRegistry.get_output_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
@@ -9,6 +9,7 @@ from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordOrderBy
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
@@ -16,16 +17,13 @@ from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
@@ -102,7 +100,9 @@ class BoardsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A list of all boards.
|
||||
"""
|
||||
return self._services.boards.get_all()
|
||||
return self._services.boards.get_all(
|
||||
order_by=BoardRecordOrderBy.CreatedAt, direction=SQLiteDirection.Descending
|
||||
)
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
"""Adds an image to a board.
|
||||
@@ -122,7 +122,11 @@ class BoardsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A list of all image names for the board.
|
||||
"""
|
||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||
return self._services.board_images.get_all_board_image_names_for_board(
|
||||
board_id,
|
||||
categories=None,
|
||||
is_intermediate=None,
|
||||
)
|
||||
|
||||
|
||||
class LoggerInterface(InvocationContextInterface):
|
||||
@@ -283,7 +287,7 @@ class ImagesInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
The local path of the image or thumbnail.
|
||||
"""
|
||||
return self._services.images.get_path(image_name, thumbnail)
|
||||
return Path(self._services.images.get_path(image_name, thumbnail))
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
@@ -38,14 +37,20 @@ class SqliteDatabase:
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.lock = threading.RLock()
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
|
||||
# Enable foreign key constraints
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
self.conn.execute("PRAGMA journal_mode = WAL;")
|
||||
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
@@ -53,15 +58,14 @@ class SqliteDatabase:
|
||||
# No need to clean in-memory database
|
||||
if not self.db_path:
|
||||
return
|
||||
with self.lock:
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
@@ -19,6 +19,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import build_migration_16
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -55,6 +57,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_14())
|
||||
migrator.register_migration(build_migration_15())
|
||||
migrator.register_migration(build_migration_16())
|
||||
migrator.register_migration(build_migration_17())
|
||||
migrator.register_migration(build_migration_18())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration17Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._add_workflows_tags_col(cursor)
|
||||
|
||||
def _add_workflows_tags_col(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
- Adds `tags` column to the workflow_library table. It is a generated column that extracts the tags from the
|
||||
workflow JSON.
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
"ALTER TABLE workflow_library ADD COLUMN tags TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.tags')) VIRTUAL;"
|
||||
)
|
||||
|
||||
|
||||
def build_migration_17() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 16 to 17.
|
||||
|
||||
This migration does the following:
|
||||
- Adds `tags` column to the workflow_library table. It is a generated column that extracts the tags from the
|
||||
workflow JSON.
|
||||
"""
|
||||
migration_17 = Migration(
|
||||
from_version=16,
|
||||
to_version=17,
|
||||
callback=Migration17Callback(),
|
||||
)
|
||||
|
||||
return migration_17
|
||||
@@ -0,0 +1,47 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration18Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._make_workflow_opened_at_nullable(cursor)
|
||||
|
||||
def _make_workflow_opened_at_nullable(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Make the `opened_at` column nullable in the `workflow_library` table. This is accomplished by:
|
||||
- Dropping the existing `idx_workflow_library_opened_at` index (must be done before dropping the column)
|
||||
- Dropping the existing `opened_at` column
|
||||
- Adding a new nullable column `opened_at` (no data migration needed, all values will be NULL)
|
||||
- Adding a new `idx_workflow_library_opened_at` index on the `opened_at` column
|
||||
"""
|
||||
# For index renaming in SQLite, we need to drop and recreate
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_workflow_library_opened_at;")
|
||||
# Rename existing column to deprecated
|
||||
cursor.execute("ALTER TABLE workflow_library DROP COLUMN opened_at;")
|
||||
# Add new nullable column - all values will be NULL - no migration of data needed
|
||||
cursor.execute("ALTER TABLE workflow_library ADD COLUMN opened_at DATETIME;")
|
||||
# Create new index on the new column
|
||||
cursor.execute(
|
||||
"CREATE INDEX idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
||||
)
|
||||
|
||||
|
||||
def build_migration_18() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 17 to 18.
|
||||
|
||||
This migration does the following:
|
||||
- Make the `opened_at` column nullable in the `workflow_library` table. This is accomplished by:
|
||||
- Dropping the existing `idx_workflow_library_opened_at` index (must be done before dropping the column)
|
||||
- Dropping the existing `opened_at` column
|
||||
- Adding a new nullable column `opened_at` (no data migration needed, all values will be NULL)
|
||||
- Adding a new `idx_workflow_library_opened_at` index on the `opened_at` column
|
||||
"""
|
||||
migration_18 = Migration(
|
||||
from_version=17,
|
||||
to_version=18,
|
||||
callback=Migration18Callback(),
|
||||
)
|
||||
|
||||
return migration_18
|
||||
@@ -43,46 +43,45 @@ class SqliteMigrator:
|
||||
|
||||
def run_migrations(self) -> bool:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._db.lock:
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return False
|
||||
if self._migration_set.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return False
|
||||
|
||||
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return False
|
||||
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return False
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migration_set.get(self._get_current_version(cursor))
|
||||
self._logger.info("Database updated successfully")
|
||||
return True
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migration_set.get(self._get_current_version(cursor))
|
||||
self._logger.info("Database updated successfully")
|
||||
return True
|
||||
|
||||
def _run_migration(self, migration: Migration) -> None:
|
||||
"""Runs a single migration."""
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db.lock, self._db.conn as conn:
|
||||
with self._db.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
@@ -108,27 +107,26 @@ class SqliteMigrator:
|
||||
|
||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the migrations table for the database, if one does not already exist."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
return
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
cursor.connection.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
cursor.connection.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
return
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
cursor.connection.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
cursor.connection.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
@classmethod
|
||||
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user