mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 14:27:56 -05:00
Compare commits
516 Commits
feat/seaml
...
lstein/mod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53e1199902 | ||
|
|
0f9c676fcb | ||
|
|
a51b165a40 | ||
|
|
5f80d4dd07 | ||
|
|
b708aef5cc | ||
|
|
aace679505 | ||
|
|
a2079bdd70 | ||
|
|
0a0412f75f | ||
|
|
e079cc9f07 | ||
|
|
76aa19a0f7 | ||
|
|
71e7e61c0f | ||
|
|
67607f053d | ||
|
|
4bab724288 | ||
|
|
e50a257198 | ||
|
|
4149d357bf | ||
|
|
33d4756c48 | ||
|
|
1f751f8c21 | ||
|
|
ca95a3bd0d | ||
|
|
55b40a9425 | ||
|
|
90083cc88d | ||
|
|
3962914f7d | ||
|
|
3644d40e04 | ||
|
|
fe1038665c | ||
|
|
a80ff75b52 | ||
|
|
ead754432a | ||
|
|
ce2baa36a9 | ||
|
|
bccfe8b3cc | ||
|
|
fa9ea93477 | ||
|
|
fe0cf2c160 | ||
|
|
e5b2bc8532 | ||
|
|
a64a34b49a | ||
|
|
51060543dc | ||
|
|
7f68f58cf7 | ||
|
|
432231ea18 | ||
|
|
44216381cb | ||
|
|
00e85bcd67 | ||
|
|
6303f74616 | ||
|
|
a681fa4b03 | ||
|
|
1cc686734b | ||
|
|
82e8b92ba0 | ||
|
|
e86658f864 | ||
|
|
ad136c2680 | ||
|
|
35374ec531 | ||
|
|
ed82bf6bb8 | ||
|
|
078c9b6964 | ||
|
|
1a9d2f1701 | ||
|
|
3e93159bce | ||
|
|
b57ebe52e4 | ||
|
|
ba4616ff89 | ||
|
|
dcfbd49e1b | ||
|
|
913fc83cbf | ||
|
|
6b8ce34eb3 | ||
|
|
9508e0c9db | ||
|
|
9c720da021 | ||
|
|
e1b576c72d | ||
|
|
971ccfb081 | ||
|
|
43a3c3c7ea | ||
|
|
4df1cdb34d | ||
|
|
3f860c3523 | ||
|
|
d8d0c9af09 | ||
|
|
9403672ac0 | ||
|
|
94591840a7 | ||
|
|
26b91a538a | ||
|
|
7ca456d674 | ||
|
|
78828b6b9c | ||
|
|
166ff9d301 | ||
|
|
4f97bd4418 | ||
|
|
e0e001758a | ||
|
|
8e06088152 | ||
|
|
c1887135b3 | ||
|
|
096d195d6e | ||
|
|
7870b90717 | ||
|
|
9854b244fd | ||
|
|
7d800e1ce3 | ||
|
|
1c8b1fbc53 | ||
|
|
594a3aef93 | ||
|
|
78377469db | ||
|
|
9cbc62d8d3 | ||
|
|
cd5d3e30c7 | ||
|
|
fbe6452c45 | ||
|
|
3f4ea073d1 | ||
|
|
8b7f8eaea2 | ||
|
|
88e16ce051 | ||
|
|
421440cae0 | ||
|
|
421021cede | ||
|
|
020d4302d1 | ||
|
|
8c59d2e5af | ||
|
|
17d451eaa7 | ||
|
|
23a06fd06d | ||
|
|
010c8e8038 | ||
|
|
dfc635223c | ||
|
|
37121a3a24 | ||
|
|
51b5de799a | ||
|
|
eadbe6abf7 | ||
|
|
16f48a816f | ||
|
|
95838e5559 | ||
|
|
3e8d62b1d1 | ||
|
|
2acc93eb8e | ||
|
|
cb0fdf3394 | ||
|
|
a180c0f241 | ||
|
|
16ec7a323b | ||
|
|
de90d4068b | ||
|
|
fbb61f2334 | ||
|
|
be85c7972b | ||
|
|
3a586fc9c4 | ||
|
|
4624de0151 | ||
|
|
459f0238dd | ||
|
|
dedead672f | ||
|
|
67366921c0 | ||
|
|
5a1019d858 | ||
|
|
f4ba7be918 | ||
|
|
069d8b5812 | ||
|
|
24d73d484a | ||
|
|
e3912e8826 | ||
|
|
062a6ed180 | ||
|
|
2479a59e5e | ||
|
|
7d0ac2c36d | ||
|
|
519b892f0c | ||
|
|
763dcacfd3 | ||
|
|
3599d546e6 | ||
|
|
22a84930f6 | ||
|
|
d64e17e043 | ||
|
|
ba54277011 | ||
|
|
5915a4a51c | ||
|
|
4580ba0d87 | ||
|
|
b9fd2e9e76 | ||
|
|
75b65597af | ||
|
|
2a3c0ab5d2 | ||
|
|
7d61373b82 | ||
|
|
7d65555a5a | ||
|
|
123f2b2dbc | ||
|
|
1e4e42556e | ||
|
|
1f6699ac43 | ||
|
|
ace8665411 | ||
|
|
7fa5bae8fd | ||
|
|
f9faca7c91 | ||
|
|
594fd3ba6d | ||
|
|
44d68f5ed5 | ||
|
|
4bda7d7df5 | ||
|
|
920c5dd686 | ||
|
|
4ce00a32f4 | ||
|
|
dcbb25dfea | ||
|
|
6c8270dae2 | ||
|
|
48c3d926b0 | ||
|
|
63f6c12aa3 | ||
|
|
c91429d4ab | ||
|
|
b19572199f | ||
|
|
a673c0aa14 | ||
|
|
955ef3bc54 | ||
|
|
f002ae8da5 | ||
|
|
208bf68ba2 | ||
|
|
1aba369c83 | ||
|
|
9ac11e793c | ||
|
|
9b39888e2f | ||
|
|
c1715144f0 | ||
|
|
230ee18536 | ||
|
|
c025c9c4ed | ||
|
|
929557bc6f | ||
|
|
811dd93912 | ||
|
|
acaaff4b7e | ||
|
|
807ae821ea | ||
|
|
208d390779 | ||
|
|
9a60dbd5cb | ||
|
|
637c5b0747 | ||
|
|
27164de8b8 | ||
|
|
08e40d6d16 | ||
|
|
d905c54795 | ||
|
|
dc1e804887 | ||
|
|
95fd2ee6ff | ||
|
|
5f4eb0c3b3 | ||
|
|
cbf0310a2c | ||
|
|
4555aec17c | ||
|
|
3b832f1db2 | ||
|
|
d464ce509b | ||
|
|
2f16a2c35d | ||
|
|
3909e68527 | ||
|
|
848e51f72b | ||
|
|
52f8c9e16f | ||
|
|
5174f382b9 | ||
|
|
81fce18c73 | ||
|
|
c7f80cd163 | ||
|
|
309e2414ce | ||
|
|
6704f77d87 | ||
|
|
045d3f6139 | ||
|
|
0b75a4fbb5 | ||
|
|
a0bd8c638e | ||
|
|
de04a5f441 | ||
|
|
40ed218c26 | ||
|
|
807c6b41c5 | ||
|
|
f6bbcd0589 | ||
|
|
ada22a799e | ||
|
|
a42ef9c855 | ||
|
|
034af2d9f8 | ||
|
|
676ccd8ebb | ||
|
|
a263a4f4cc | ||
|
|
ef0754cdec | ||
|
|
8158124679 | ||
|
|
5d31df0cb7 | ||
|
|
bd63454e51 | ||
|
|
062df07de2 | ||
|
|
0fc14afcf0 | ||
|
|
4a0a1c30db | ||
|
|
3432fd72f8 | ||
|
|
05a43c41f9 | ||
|
|
bb48617101 | ||
|
|
aa2f68f608 | ||
|
|
fbccce7573 | ||
|
|
a35087ee6e | ||
|
|
03e463dc89 | ||
|
|
d467e138a4 | ||
|
|
ba4aaea45b | ||
|
|
53eb23b8b6 | ||
|
|
8b969053e7 | ||
|
|
98a076260b | ||
|
|
164877b610 | ||
|
|
b3f4f28d76 | ||
|
|
acee4bd282 | ||
|
|
fc9a7320eb | ||
|
|
7c0a083b13 | ||
|
|
50d254fdb7 | ||
|
|
0cfc1c5f86 | ||
|
|
f35dfa06bb | ||
|
|
407bca5063 | ||
|
|
1419977e89 | ||
|
|
a953944894 | ||
|
|
a4cdaa245e | ||
|
|
105a4234b0 | ||
|
|
34c563060f | ||
|
|
d45c47db81 | ||
|
|
c771a4027f | ||
|
|
3fd27b1aa9 | ||
|
|
d59e534cad | ||
|
|
0c97a1e7e7 | ||
|
|
c8b306d9f8 | ||
|
|
edd2c54b9e | ||
|
|
727cc0dafe | ||
|
|
4530bd46dc | ||
|
|
c8b109f52e | ||
|
|
2e9a7b0454 | ||
|
|
1d6a4e7ee7 | ||
|
|
a2613948d8 | ||
|
|
f8392b2f78 | ||
|
|
358116bc22 | ||
|
|
1e3590111d | ||
|
|
063b800280 | ||
|
|
3935bf92c8 | ||
|
|
066e09b517 | ||
|
|
869b4a8d49 | ||
|
|
399ebe443e | ||
|
|
13919ff300 | ||
|
|
634e5652ef | ||
|
|
effced8560 | ||
|
|
ac4634000a | ||
|
|
9bdc718df5 | ||
|
|
73ca8ccdb3 | ||
|
|
f37ffda966 | ||
|
|
5a9777d443 | ||
|
|
8072c05ee0 | ||
|
|
75ff4f4ca3 | ||
|
|
30df123221 | ||
|
|
06193ddbe8 | ||
|
|
ce5122f87c | ||
|
|
43ebd68313 | ||
|
|
ec19fcafb1 | ||
|
|
6fcc7d4c4b | ||
|
|
912087e4dc | ||
|
|
593fb95213 | ||
|
|
6d821b32d3 | ||
|
|
297f96c16b | ||
|
|
0e53b27655 | ||
|
|
35ae9f6e71 | ||
|
|
a1d9e6b871 | ||
|
|
f05379f965 | ||
|
|
e34e6d6e80 | ||
|
|
f9b92ddc12 | ||
|
|
8bc1ca046c | ||
|
|
86cb53342a | ||
|
|
e3de996525 | ||
|
|
25a71a1791 | ||
|
|
6edee2d22b | ||
|
|
ab58eb29c5 | ||
|
|
d16583ad1c | ||
|
|
46db1dd18f | ||
|
|
d5d517d2fa | ||
|
|
d2cdbe5c4e | ||
|
|
4c9344b0ee | ||
|
|
cba31efd78 | ||
|
|
4d01b5c0f2 | ||
|
|
e02af8f518 | ||
|
|
c485cf568b | ||
|
|
51451cbf21 | ||
|
|
0363a06963 | ||
|
|
cc280cbef1 | ||
|
|
7544eadd48 | ||
|
|
7d683b4db6 | ||
|
|
07ddd601e1 | ||
|
|
60b3c6a201 | ||
|
|
88c8cb61f0 | ||
|
|
43fbac26df | ||
|
|
627444e17c | ||
|
|
5601858f4f | ||
|
|
c9cd418ed8 | ||
|
|
b152fbf72f | ||
|
|
f95111772a | ||
|
|
14ce7cf09c | ||
|
|
28a1a6939f | ||
|
|
6d2b4013f8 | ||
|
|
30aea54f1a | ||
|
|
ca7a7b57bb | ||
|
|
c5d0e65a24 | ||
|
|
6cc7b55ec5 | ||
|
|
883e9973ec | ||
|
|
9e7d829906 | ||
|
|
456a0a59e0 | ||
|
|
4f2bf7e7e8 | ||
|
|
77e93888cf | ||
|
|
fa54974bff | ||
|
|
7ac99d6bc3 | ||
|
|
aa82f9360c | ||
|
|
5aefa49d7d | ||
|
|
b6e9cd4fe2 | ||
|
|
6d1057c560 | ||
|
|
b4790002c7 | ||
|
|
e02700a782 | ||
|
|
83ce8ef1ec | ||
|
|
19e487b5ee | ||
|
|
aa4b56baf2 | ||
|
|
d3a2be69f1 | ||
|
|
02c087ee37 | ||
|
|
cab8d9bb20 | ||
|
|
28e6a7139b | ||
|
|
1625854eaf | ||
|
|
3199409fd3 | ||
|
|
f87b042162 | ||
|
|
3402cf6542 | ||
|
|
183e2c3ee0 | ||
|
|
098d506b95 | ||
|
|
7aa33c352b | ||
|
|
bf62553150 | ||
|
|
2b08d9e53b | ||
|
|
8954953eca | ||
|
|
ed91f48a92 | ||
|
|
eb2fcbe28a | ||
|
|
e78b36a9f7 | ||
|
|
144ede031e | ||
|
|
8ca37bba33 | ||
|
|
a608340c89 | ||
|
|
7fecebf7db | ||
|
|
b915d74127 | ||
|
|
6ec347bd41 | ||
|
|
e54843acc9 | ||
|
|
0960518088 | ||
|
|
21de74fac4 | ||
|
|
8ce9b6c51e | ||
|
|
b64ade586d | ||
|
|
3c44a74ba5 | ||
|
|
24d0901d8e | ||
|
|
b1b5f70ea6 | ||
|
|
6392098961 | ||
|
|
2c39aec22d | ||
|
|
d066bc6d19 | ||
|
|
e487bcd0f7 | ||
|
|
e0f8274f49 | ||
|
|
69e3513e90 | ||
|
|
7e706f02cb | ||
|
|
41dad2013a | ||
|
|
3f554d6824 | ||
|
|
202c5a48c6 | ||
|
|
2d71f6f4b8 | ||
|
|
0420874f56 | ||
|
|
f222b871e9 | ||
|
|
8b8d589033 | ||
|
|
f4c895257a | ||
|
|
10af5a26f2 | ||
|
|
1088adeb0a | ||
|
|
ad49380cd1 | ||
|
|
b2fe24c401 | ||
|
|
b128db1d58 | ||
|
|
f7f0630d97 | ||
|
|
5075e9c899 | ||
|
|
3c1549cf5c | ||
|
|
9faa53ceb1 | ||
|
|
32672cfeda | ||
|
|
b5266f89ad | ||
|
|
7a3b467ce0 | ||
|
|
bdfdf854fc | ||
|
|
1c38cce16d | ||
|
|
4cdca45228 | ||
|
|
bfed08673a | ||
|
|
c1aa2b82eb | ||
|
|
0a09f84b07 | ||
|
|
b7938d9ca9 | ||
|
|
977e348a35 | ||
|
|
de666fd7bc | ||
|
|
73bc088fa7 | ||
|
|
578e682562 | ||
|
|
0c8849155e | ||
|
|
d1382f232c | ||
|
|
d7ebe3f048 | ||
|
|
5c2bdf626b | ||
|
|
390a1c9fbb | ||
|
|
c46d9b8768 | ||
|
|
ef8d9843dd | ||
|
|
dc2e1a42bc | ||
|
|
151ba02022 | ||
|
|
d051c0868e | ||
|
|
238d7fa0ee | ||
|
|
f0ce559d28 | ||
|
|
e880f4bcfb | ||
|
|
539776a15a | ||
|
|
c029534243 | ||
|
|
dc683475d4 | ||
|
|
c090c5f907 | ||
|
|
db7fdc3555 | ||
|
|
b9a90fbd28 | ||
|
|
08952b9aa0 | ||
|
|
b7789bb7bb | ||
|
|
3529925234 | ||
|
|
e7a10d310f | ||
|
|
2ce07a4730 | ||
|
|
45d5ab20ec | ||
|
|
7bf7c16a5d | ||
|
|
afe9756667 | ||
|
|
fcea65770f | ||
|
|
a033ccc776 | ||
|
|
716a1b6423 | ||
|
|
171d789646 | ||
|
|
ac88863fd2 | ||
|
|
27dcd89c90 | ||
|
|
4b932b275d | ||
|
|
d219167849 | ||
|
|
090db1ab3a | ||
|
|
7b2e6deaf1 | ||
|
|
63f94579c5 | ||
|
|
3dfff278aa | ||
|
|
6d8b2a7385 | ||
|
|
88db094cf2 | ||
|
|
7430d87301 | ||
|
|
b583bddeb1 | ||
|
|
f454304c91 | ||
|
|
8052f2eb5d | ||
|
|
8636015d92 | ||
|
|
b7a6a536e6 | ||
|
|
b2892f9068 | ||
|
|
3582cfa267 | ||
|
|
64424c6db0 | ||
|
|
598fe8101e | ||
|
|
b7ca983f9c | ||
|
|
2165d55a67 | ||
|
|
a7aca29765 | ||
|
|
79b2423159 | ||
|
|
b09e012baa | ||
|
|
c9a016f1a2 | ||
|
|
b5e1ba34b3 | ||
|
|
d979c50de3 | ||
|
|
11ead34022 | ||
|
|
82499d4ef0 | ||
|
|
3448edac1a | ||
|
|
626acd5105 | ||
|
|
404cfe0eb9 | ||
|
|
e9074176bd | ||
|
|
ca6d24810c | ||
|
|
57552deab2 | ||
|
|
58aa159a50 | ||
|
|
d8f7c19030 | ||
|
|
8f51adc737 | ||
|
|
d1c5990abe | ||
|
|
8fc20925b5 | ||
|
|
869f310ae7 | ||
|
|
7df67d077a | ||
|
|
bc1bce18b0 | ||
|
|
e6512e1b9a | ||
|
|
8396bf7c99 | ||
|
|
24132a7950 | ||
|
|
dff466244d | ||
|
|
97f2e778ee | ||
|
|
93cef55964 | ||
|
|
055ad0101d | ||
|
|
9adc897302 | ||
|
|
4b3d54dbc0 | ||
|
|
6f9bf87a7a | ||
|
|
f023e342ef | ||
|
|
1784aeb343 | ||
|
|
0deb3f9e2a | ||
|
|
45d172d5a8 | ||
|
|
f5d95ffed5 | ||
|
|
6f9c1c6d4e | ||
|
|
811c82a677 | ||
|
|
4f0e43ec1b | ||
|
|
26a7b7b66d | ||
|
|
8611ffe32d | ||
|
|
3cb6d333f6 | ||
|
|
4570702dd0 | ||
|
|
1d107f30e5 | ||
|
|
79084e9e20 | ||
|
|
916cc26193 | ||
|
|
fc9b4539a3 | ||
|
|
e83d00595d | ||
|
|
1c7d9dbf40 | ||
|
|
7db71ed42e | ||
|
|
09ef57718e | ||
|
|
cab8239ba8 | ||
|
|
c56fb38855 | ||
|
|
155d9fcb13 | ||
|
|
81da3d3b23 | ||
|
|
51e84e6986 | ||
|
|
1ea0ccb7b9 | ||
|
|
5434dcd273 | ||
|
|
0c7430048e | ||
|
|
6c9b9e1787 | ||
|
|
b2894b5270 | ||
|
|
32958db6f6 | ||
|
|
e8815a1676 | ||
|
|
e8edb0d434 | ||
|
|
b5d97b18f1 | ||
|
|
ae56c000fc |
@@ -47,34 +47,9 @@ pip install ".[dev,test]"
|
||||
These are optional groups of packages which are defined within the `pyproject.toml`
|
||||
and will be required for testing the changes you make to the code.
|
||||
|
||||
### Running Tests
|
||||
|
||||
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
|
||||
be found under the `./tests` folder and can be run with a single `pytest`
|
||||
command. Optionally, to review test coverage you can append `--cov`.
|
||||
|
||||
```zsh
|
||||
pytest --cov
|
||||
```
|
||||
|
||||
Test outcomes and coverage will be reported in the terminal. In addition a more
|
||||
detailed report is created in both XML and HTML format in the `./coverage`
|
||||
folder. The HTML one in particular can help identify missing statements
|
||||
requiring tests to ensure coverage. This can be run by opening
|
||||
`./coverage/html/index.html`.
|
||||
|
||||
For example.
|
||||
|
||||
```zsh
|
||||
pytest --cov; open ./coverage/html/index.html
|
||||
```
|
||||
|
||||
??? info "HTML coverage report output"
|
||||
|
||||

|
||||
|
||||

|
||||
### Tests
|
||||
|
||||
See the [tests documentation](./TESTS.md) for information about running and writing tests.
|
||||
### Reloading Changes
|
||||
|
||||
Experimenting with changes to the Python source code is a drag if you have to re-start the server —
|
||||
@@ -167,6 +142,23 @@ and so you'll have access to the same python environment as the InvokeAI app.
|
||||
|
||||
This is _super_ handy.
|
||||
|
||||
#### Enabling Type-Checking with Pylance
|
||||
|
||||
We use python's typing system in InvokeAI. PR reviews will include checking that types are present and correct. We don't enforce types with `mypy` at this time, but that is on the horizon.
|
||||
|
||||
Using a code analysis tool to automatically type check your code (and types) is very important when writing with types. These tools provide immediate feedback in your editor when types are incorrect, and following their suggestions lead to fewer runtime bugs.
|
||||
|
||||
Pylance, installed at the beginning of this guide, is the de-facto python LSP (language server protocol). It provides type checking in the editor (among many other features). Once installed, you do need to enable type checking manually:
|
||||
|
||||
- Open a python file
|
||||
- Look along the status bar in VSCode for `{ } Python`
|
||||
- Click the `{ }`
|
||||
- Turn type checking on - basic is fine
|
||||
|
||||
You'll now see red squiggly lines where type issues are detected. Hover your cursor over the indicated symbols to see what's wrong.
|
||||
|
||||
In 99% of cases when the type checker says there is a problem, there really is a problem, and you should take some time to understand and resolve what it is pointing out.
|
||||
|
||||
#### Debugging configs with `launch.json`
|
||||
|
||||
Debugging configs are managed in a `launch.json` file. Like most VSCode configs,
|
||||
|
||||
1214
docs/contributing/MODEL_MANAGER.md
Normal file
1214
docs/contributing/MODEL_MANAGER.md
Normal file
File diff suppressed because it is too large
Load Diff
89
docs/contributing/TESTS.md
Normal file
89
docs/contributing/TESTS.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# InvokeAI Backend Tests
|
||||
|
||||
We use `pytest` to run the backend python tests. (See [pyproject.toml](/pyproject.toml) for the default `pytest` options.)
|
||||
|
||||
## Fast vs. Slow
|
||||
All tests are categorized as either 'fast' (no test annotation) or 'slow' (annotated with the `@pytest.mark.slow` decorator).
|
||||
|
||||
'Fast' tests are run to validate every PR, and are fast enough that they can be run routinely during development.
|
||||
|
||||
'Slow' tests are currently only run manually on an ad-hoc basis. In the future, they may be automated to run nightly. Most developers are only expected to run the 'slow' tests that directly relate to the feature(s) that they are working on.
|
||||
|
||||
As a rule of thumb, tests should be marked as 'slow' if there is a chance that they take >1s (e.g. on a CPU-only machine with slow internet connection). Common examples of slow tests are tests that depend on downloading a model, or running model inference.
|
||||
|
||||
## Running Tests
|
||||
|
||||
Below are some common test commands:
|
||||
```bash
|
||||
# Run the fast tests. (This implicitly uses the configured default option: `-m "not slow"`.)
|
||||
pytest tests/
|
||||
|
||||
# Equivalent command to run the fast tests.
|
||||
pytest tests/ -m "not slow"
|
||||
|
||||
# Run the slow tests.
|
||||
pytest tests/ -m "slow"
|
||||
|
||||
# Run the slow tests from a specific file.
|
||||
pytest tests/path/to/slow_test.py -m "slow"
|
||||
|
||||
# Run all tests (fast and slow).
|
||||
pytest tests -m ""
|
||||
```
|
||||
|
||||
## Test Organization
|
||||
|
||||
All backend tests are in the [`tests/`](/tests/) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.
|
||||
|
||||
TODO: The above statement is aspirational. A re-organization of legacy tests is required to make it true.
|
||||
|
||||
## Tests that depend on models
|
||||
|
||||
There are a few things to keep in mind when adding tests that depend on models.
|
||||
|
||||
1. If a required model is not already present, it should automatically be downloaded as part of the test setup.
|
||||
2. If a model is already downloaded, it should not be re-downloaded unnecessarily.
|
||||
3. Take reasonable care to keep the total number of models required for the tests low. Whenever possible, re-use models that are already required for other tests. If you are adding a new model, consider including a comment to explain why it is required/unique.
|
||||
|
||||
There are several utilities to help with model setup for tests. Here is a sample test that depends on a model:
|
||||
```python
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.util.test_utils import install_and_load_model
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model(model_installer, torch_device):
|
||||
model_info = install_and_load_model(
|
||||
model_installer=model_installer,
|
||||
model_path_id_or_url="HF/dummy_model_id",
|
||||
model_name="dummy_model",
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.Dummy,
|
||||
)
|
||||
|
||||
dummy_input = build_dummy_input(torch_device)
|
||||
|
||||
with torch.no_grad(), model_info as model:
|
||||
model.to(torch_device, dtype=torch.float32)
|
||||
output = model(dummy_input)
|
||||
|
||||
# Validate output...
|
||||
|
||||
```
|
||||
|
||||
## Test Coverage
|
||||
|
||||
To review test coverage, append `--cov` to your pytest command:
|
||||
```bash
|
||||
pytest tests/ --cov
|
||||
```
|
||||
|
||||
Test outcomes and coverage will be reported in the terminal. In addition, a more detailed report is created in both XML and HTML format in the `./coverage` folder. The HTML output is particularly helpful in identifying untested statements where coverage should be improved. The HTML report can be viewed by opening `./coverage/html/index.html`.
|
||||
|
||||
??? info "HTML coverage report output"
|
||||
|
||||

|
||||
|
||||

|
||||
@@ -12,8 +12,9 @@ To get started, take a look at our [new contributors checklist](newContributorCh
|
||||
Once you're setup, for more information, you can review the documentation specific to your area of interest:
|
||||
|
||||
* #### [InvokeAI Architecure](../ARCHITECTURE.md)
|
||||
* #### [Frontend Documentation](development_guides/contributingToFrontend.md)
|
||||
* #### [Frontend Documentation](./contributingToFrontend.md)
|
||||
* #### [Node Documentation](../INVOCATIONS.md)
|
||||
* #### [InvokeAI Model Manager](../MODEL_MANAGER.md)
|
||||
* #### [Local Development](../LOCAL_DEVELOPMENT.md)
|
||||
|
||||
|
||||
@@ -38,9 +39,9 @@ There are two paths to making a development contribution:
|
||||
|
||||
If you need help, you can ask questions in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord.
|
||||
|
||||
For frontend related work, **@pyschedelicious** is the best person to reach out to.
|
||||
For frontend related work, **@psychedelicious** is the best person to reach out to.
|
||||
|
||||
For backend related work, please reach out to **@blessedcoolant**, **@lstein**, **@StAlKeR7779** or **@pyschedelicious**.
|
||||
For backend related work, please reach out to **@blessedcoolant**, **@lstein**, **@StAlKeR7779** or **@psychedelicious**.
|
||||
|
||||
|
||||
## **What does the Code of Conduct mean for me?**
|
||||
|
||||
@@ -10,4 +10,4 @@ When updating or creating documentation, please keep in mind InvokeAI is a tool
|
||||
|
||||
## Help & Questions
|
||||
|
||||
Please ping @imic1 or @hipsterusername in the [Discord](https://discord.com/channels/1020123559063990373/1049495067846524939) if you have any questions.
|
||||
Please ping @imic or @hipsterusername in the [Discord](https://discord.com/channels/1020123559063990373/1049495067846524939) if you have any questions.
|
||||
@@ -159,7 +159,7 @@ groups in `invokeia.yaml`:
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
|
||||
@@ -207,11 +207,8 @@ if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory (not recommended)|
|
||||
| `model_config_db` | `auto` | Location of the model configuration database. Specify `auto` to use the main invokeai.db database, or specify a `.yaml` or `.db` file to store the data externally.|
|
||||
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||
@@ -234,6 +231,18 @@ Paths:
|
||||
# controlnet_dir: null
|
||||
```
|
||||
|
||||
### Model Cache
|
||||
|
||||
These options control the size of various caches that InvokeAI uses
|
||||
during the model loading and conversion process. All units are in GB
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `disk` | `20.0` | Before loading a model into memory, InvokeAI converts .ckpt and .safetensors models into diffusers format and saves them to disk. This option controls the maximum size of the directory in which these converted models are stored. If set to zero, then only the most recently-used model will be cached. |
|
||||
| `ram` | `6.0` | After loading a model from disk, it is kept in system RAM until it is needed again. This option controls how much RAM is set aside for this purpose. Larger amounts allow more models to reside in RAM and for InvokeAI to quickly switch between them. |
|
||||
| `vram` | `0.25` | This allows smaller models to remain in VRAM, speeding up execution modestly. It should be a small number. |
|
||||
|
||||
|
||||
### Logging
|
||||
|
||||
These settings control the information, warning, and debugging
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
---
|
||||
title: ControlNet
|
||||
title: Control Adapters
|
||||
---
|
||||
|
||||
# :material-loupe: ControlNet
|
||||
# :material-loupe: Control Adapters
|
||||
|
||||
## ControlNet
|
||||
|
||||
ControlNet
|
||||
|
||||
ControlNet is a powerful set of features developed by the open-source
|
||||
community (notably, Stanford researcher
|
||||
[**@ilyasviel**](https://github.com/lllyasviel)) that allows you to
|
||||
@@ -20,7 +18,7 @@ towards generating images that better fit your desired style or
|
||||
outcome.
|
||||
|
||||
|
||||
### How it works
|
||||
#### How it works
|
||||
|
||||
ControlNet works by analyzing an input image, pre-processing that
|
||||
image to identify relevant information that can be interpreted by each
|
||||
@@ -30,7 +28,7 @@ composition, or other aspects of the image to better achieve a
|
||||
specific result.
|
||||
|
||||
|
||||
### Models
|
||||
#### Models
|
||||
|
||||
InvokeAI provides access to a series of ControlNet models that provide
|
||||
different effects or styles in your generated images. Currently
|
||||
@@ -96,6 +94,8 @@ A model that generates normal maps from input images, allowing for more realisti
|
||||
**Image Segmentation**:
|
||||
A model that divides input images into segments or regions, each of which corresponds to a different object or part of the image. (More details coming soon)
|
||||
|
||||
**QR Code Monster**:
|
||||
A model that helps generate creative QR codes that still scan. Can also be used to create images with text, logos or shapes within them.
|
||||
|
||||
**Openpose**:
|
||||
The OpenPose control model allows for the identification of the general pose of a character by pre-processing an existing image with a clear human structure. With advanced options, Openpose can also detect the face or hands in the image.
|
||||
@@ -120,7 +120,7 @@ With Pix2Pix, you can input an image into the controlnet, and then "instruct" th
|
||||
Each of these models can be adjusted and combined with other ControlNet models to achieve different results, giving you even more control over your image generation process.
|
||||
|
||||
|
||||
## Using ControlNet
|
||||
### Using ControlNet
|
||||
|
||||
To use ControlNet, you can simply select the desired model and adjust both the ControlNet and Pre-processor settings to achieve the desired result. You can also use multiple ControlNet models at the same time, allowing you to achieve even more complex effects or styles in your generated images.
|
||||
|
||||
@@ -132,3 +132,31 @@ Weight - Strength of the Controlnet model applied to the generation for the sect
|
||||
Start/End - 0 represents the start of the generation, 1 represents the end. The Start/end setting controls what steps during the generation process have the ControlNet applied.
|
||||
|
||||
Additionally, each ControlNet section can be expanded in order to manipulate settings for the image pre-processor that adjusts your uploaded image before using it in when you Invoke.
|
||||
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
[IP-Adapter](https://ip-adapter.github.io) is a tooling that allows for image prompt capabilities with text-to-image diffusion models. IP-Adapter works by analyzing the given image prompt to extract features, then passing those features to the UNet along with any other conditioning provided.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
#### Installation
|
||||
There are several ways to install IP-Adapter models with an existing InvokeAI installation:
|
||||
|
||||
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [5] to download models.
|
||||
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
|
||||
3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
|
||||
|
||||
#### Using IP-Adapter
|
||||
|
||||
IP-Adapter can be used by navigating to the *Control Adapters* options and enabling IP-Adapter.
|
||||
|
||||
IP-Adapter requires an image to be used as the Image Prompt. It can also be used in conjunction with text prompts, Image-to-Image, Inpainting, Outpainting, ControlNets and LoRAs.
|
||||
|
||||
|
||||
Each IP-Adapter has two settings that are applied to the IP-Adapter:
|
||||
|
||||
* Weight - Strength of the IP-Adapter model applied to the generation for the section, defined by start/end
|
||||
* Start/End - 0 represents the start of the generation, 1 represents the end. The Start/end setting controls what steps during the generation process have the IP-Adapter applied.
|
||||
|
||||
336
docs/features/UTILITIES.md
Normal file
336
docs/features/UTILITIES.md
Normal file
@@ -0,0 +1,336 @@
|
||||
---
|
||||
title: Command-line Utilities
|
||||
---
|
||||
|
||||
# :material-file-document: Utilities
|
||||
|
||||
# Command-line Utilities
|
||||
|
||||
InvokeAI comes with several scripts that are accessible via the
|
||||
command line. To access these commands, start the "developer's
|
||||
console" from the launcher (`invoke.bat` menu item [8]). Users who are
|
||||
familiar with Python can alternatively activate InvokeAI's virtual
|
||||
environment (typically, but not necessarily `invokeai/.venv`).
|
||||
|
||||
In the developer's console, type the script's name to run it. To get a
|
||||
synopsis of what a utility does and the command-line arguments it
|
||||
accepts, pass it the `-h` argument, e.g.
|
||||
|
||||
```bash
|
||||
invokeai-merge -h
|
||||
```
|
||||
## **invokeai-web**
|
||||
|
||||
This script launches the web server and is effectively identical to
|
||||
selecting option [1] in the launcher. An advantage of launching the
|
||||
server from the command line is that you can override any setting
|
||||
configuration option in `invokeai.yaml` using like-named command-line
|
||||
arguments. For example, to temporarily change the size of the RAM
|
||||
cache to 7 GB, you can launch as follows:
|
||||
|
||||
```bash
|
||||
invokeai-web --ram 7
|
||||
```
|
||||
|
||||
## **invokeai-merge**
|
||||
|
||||
This is the model merge script, the same as launcher option [4]. Call
|
||||
it with the `--gui` command-line argument to start the interactive
|
||||
console-based GUI. Alternatively, you can run it non-interactively
|
||||
using command-line arguments as illustrated in the example below which
|
||||
merges models named `stable-diffusion-1.5` and `inkdiffusion` into a new model named
|
||||
`my_new_model`:
|
||||
|
||||
```bash
|
||||
invokeai-merge --force --base-model sd-1 --models stable-diffusion-1.5 inkdiffusion --merged_model_name my_new_model
|
||||
```
|
||||
|
||||
## **invokeai-ti**
|
||||
|
||||
This is the textual inversion training script that is run by launcher
|
||||
option [3]. Call it with `--gui` to run the interactive console-based
|
||||
front end. It can also be run non-interactively. It has about a
|
||||
zillion arguments, but a typical training session can be launched
|
||||
with:
|
||||
|
||||
```bash
|
||||
invokeai-ti --model stable-diffusion-1.5 \
|
||||
--placeholder_token 'jello' \
|
||||
--learnable_property object \
|
||||
--num_train_epochs 50 \
|
||||
--train_data_dir /path/to/training/images \
|
||||
--output_dir /path/to/trained/model
|
||||
```
|
||||
|
||||
(Note that \\ is the Linux/Mac long-line continuation character. Use ^
|
||||
in Windows).
|
||||
|
||||
## **invokeai-install**
|
||||
|
||||
This is the console-based model install script that is run by launcher
|
||||
option [5]. If called without arguments, it will launch the
|
||||
interactive console-based interface. It can also be used
|
||||
non-interactively to list, add and remove models as shown by these
|
||||
examples:
|
||||
|
||||
* This will download and install three models from CivitAI, HuggingFace,
|
||||
and local disk:
|
||||
|
||||
```bash
|
||||
invokeai-install --add https://civitai.com/api/download/models/161302 ^
|
||||
gsdf/Counterfeit-V3.0 ^
|
||||
D:\Models\merge_model_two.safetensors
|
||||
```
|
||||
(Note that ^ is the Windows long-line continuation character. Use \\ on
|
||||
Linux/Mac).
|
||||
|
||||
* This will list installed models of type `main`:
|
||||
|
||||
```bash
|
||||
invokeai-model-install --list-models main
|
||||
```
|
||||
|
||||
* This will delete the models named `voxel-ish` and `realisticVision`:
|
||||
|
||||
```bash
|
||||
invokeai-model-install --delete voxel-ish realisticVision
|
||||
```
|
||||
|
||||
## **invokeai-configure**
|
||||
|
||||
This is the console-based configure script that ran when InvokeAI was
|
||||
first installed. You can run it again at any time to change the
|
||||
configuration, repair a broken install.
|
||||
|
||||
Called without any arguments, `invokeai-configure` enters interactive
|
||||
mode with two screens. The first screen is a form that provides access
|
||||
to most of InvokeAI's configuration options. The second screen lets
|
||||
you download, add, and delete models interactively. When you exit the
|
||||
second screen, the script will add any missing "support models"
|
||||
needed for core functionality, and any selected "sd weights" which are
|
||||
the model checkpoint/diffusers files.
|
||||
|
||||
This behavior can be changed via a series of command-line
|
||||
arguments. Here are some of the useful ones:
|
||||
|
||||
* `invokeai-configure --skip-sd-weights --skip-support-models`
|
||||
This will run just the configuration part of the utility, skipping
|
||||
downloading of support models and stable diffusion weights.
|
||||
|
||||
* `invokeai-configure --yes`
|
||||
This will run the configure script non-interactively. It will set the
|
||||
configuration options to their default values, install/repair support
|
||||
models, and download the "recommended" set of SD models.
|
||||
|
||||
* `invokeai-configure --yes --default_only`
|
||||
This will run the configure script non-interactively. In contrast to
|
||||
the previous command, it will only download the default SD model,
|
||||
Stable Diffusion v1.5
|
||||
|
||||
* `invokeai-configure --yes --default_only --skip-sd-weights`
|
||||
This is similar to the previous command, but will not download any
|
||||
SD models at all. It is usually used to repair a broken install.
|
||||
|
||||
By default, `invokeai-configure` runs on the currently active InvokeAI
|
||||
root folder. To run it against a different root, pass it the `--root
|
||||
</path/to/root>` argument.
|
||||
|
||||
Lastly, you can use `invokeai-configure` to create a working root
|
||||
directory entirely from scratch. Assuming you wish to make a root directory
|
||||
named `InvokeAI-New`, run this command:
|
||||
|
||||
```bash
|
||||
invokeai-configure --root InvokeAI-New --yes --default_only
|
||||
```
|
||||
This will create a minimally functional root directory. You can now
|
||||
launch the web server against it with `invokeai-web --root InvokeAI-New`.
|
||||
|
||||
## **invokeai-update**
|
||||
|
||||
This is the interactive console-based script that is run by launcher
|
||||
menu item [9] to update to a new version of InvokeAI. It takes no
|
||||
command-line arguments.
|
||||
|
||||
## **invokeai-metadata**
|
||||
|
||||
This is a script which takes a list of InvokeAI-generated images and
|
||||
outputs their metadata in the same JSON format that you get from the
|
||||
`</>` button in the Web GUI. For example:
|
||||
|
||||
```bash
|
||||
$ invokeai-metadata ffe2a115-b492-493c-afff-7679aa034b50.png
|
||||
ffe2a115-b492-493c-afff-7679aa034b50.png:
|
||||
{
|
||||
"app_version": "3.1.0",
|
||||
"cfg_scale": 8.0,
|
||||
"clip_skip": 0,
|
||||
"controlnets": [],
|
||||
"generation_mode": "sdxl_txt2img",
|
||||
"height": 1024,
|
||||
"loras": [],
|
||||
"model": {
|
||||
"base_model": "sdxl",
|
||||
"model_name": "stable-diffusion-xl-base-1.0",
|
||||
"model_type": "main"
|
||||
},
|
||||
"negative_prompt": "",
|
||||
"negative_style_prompt": "",
|
||||
"positive_prompt": "military grade sushi dinner for shock troopers",
|
||||
"positive_style_prompt": "",
|
||||
"rand_device": "cpu",
|
||||
"refiner_cfg_scale": 7.5,
|
||||
"refiner_model": {
|
||||
"base_model": "sdxl-refiner",
|
||||
"model_name": "sd_xl_refiner_1.0",
|
||||
"model_type": "main"
|
||||
},
|
||||
"refiner_negative_aesthetic_score": 2.5,
|
||||
"refiner_positive_aesthetic_score": 6.0,
|
||||
"refiner_scheduler": "euler",
|
||||
"refiner_start": 0.8,
|
||||
"refiner_steps": 20,
|
||||
"scheduler": "euler",
|
||||
"seed": 387129902,
|
||||
"steps": 25,
|
||||
"width": 1024
|
||||
}
|
||||
```
|
||||
|
||||
You may list multiple files on the command line.
|
||||
|
||||
## **invokeai-import-images**
|
||||
|
||||
InvokeAI uses a database to store information about images it
|
||||
generated, and just copying the image files from one InvokeAI root
|
||||
directory to another does not automatically import those images into
|
||||
the destination's gallery. This script allows you to bulk import
|
||||
images generated by one instance of InvokeAI into a gallery maintained
|
||||
by another. It also works on images generated by older versions of
|
||||
InvokeAI, going way back to version 1.
|
||||
|
||||
This script has an interactive mode only. The following example shows
|
||||
it in action:
|
||||
|
||||
```bash
|
||||
$ invokeai-import-images
|
||||
===============================================================================
|
||||
This script will import images generated by earlier versions of
|
||||
InvokeAI into the currently installed root directory:
|
||||
/home/XXXX/invokeai-main
|
||||
If this is not what you want to do, type ctrl-C now to cancel.
|
||||
===============================================================================
|
||||
= Configuration & Settings
|
||||
Found invokeai.yaml file at /home/XXXX/invokeai-main/invokeai.yaml:
|
||||
Database : /home/XXXX/invokeai-main/databases/invokeai.db
|
||||
Outputs : /home/XXXX/invokeai-main/outputs/images
|
||||
|
||||
Use these paths for import (yes) or choose different ones (no) [Yn]:
|
||||
Inputs: Specify absolute path containing InvokeAI .png images to import: /home/XXXX/invokeai-2.3/outputs/images/
|
||||
Include files from subfolders recursively [yN]?
|
||||
|
||||
Options for board selection for imported images:
|
||||
1) Select an existing board name. (found 4)
|
||||
2) Specify a board name to create/add to.
|
||||
3) Create/add to board named 'IMPORT'.
|
||||
4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_20230919T203519Z).
|
||||
5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5).
|
||||
Specify desired board option: 3
|
||||
|
||||
===============================================================================
|
||||
= Import Settings Confirmation
|
||||
|
||||
Database File Path : /home/XXXX/invokeai-main/databases/invokeai.db
|
||||
Outputs/Images Directory : /home/XXXX/invokeai-main/outputs/images
|
||||
Import Image Source Directory : /home/XXXX/invokeai-2.3/outputs/images/
|
||||
Recurse Source SubDirectories : No
|
||||
Count of .png file(s) found : 5785
|
||||
Board name option specified : IMPORT
|
||||
Database backup will be taken at : /home/XXXX/invokeai-main/databases/backup
|
||||
|
||||
Notes about the import process:
|
||||
- Source image files will not be modified, only copied to the outputs directory.
|
||||
- If the same file name already exists in the destination, the file will be skipped.
|
||||
- If the same file name already has a record in the database, the file will be skipped.
|
||||
- Invoke AI metadata tags will be updated/written into the imported copy only.
|
||||
- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)
|
||||
- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer.
|
||||
- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder.
|
||||
|
||||
Do you wish to continue with the import [Yn] ?
|
||||
|
||||
Making DB Backup at /home/lstein/invokeai-main/databases/backup/backup-20230919T203519Z-invokeai.db...Done!
|
||||
|
||||
===============================================================================
|
||||
Importing /home/XXXX/invokeai-2.3/outputs/images/17d09907-297d-4db3-a18a-60b337feac66.png
|
||||
... (5785 more lines) ...
|
||||
===============================================================================
|
||||
= Import Complete - Elpased Time: 0.28 second(s)
|
||||
|
||||
Source File(s) : 5785
|
||||
Total Imported : 5783
|
||||
Skipped b/c file already exists on disk : 1
|
||||
Skipped b/c file already exists in db : 0
|
||||
Errors during import : 1
|
||||
```
|
||||
## **invokeai-db-maintenance**
|
||||
|
||||
This script helps maintain the integrity of your InvokeAI database by
|
||||
finding and fixing three problems that can arise over time:
|
||||
|
||||
1. An image was manually deleted from the outputs directory, leaving a
|
||||
dangling image record in the InvokeAI database. This will cause a
|
||||
black image to appear in the gallery. This is an "orphaned database
|
||||
image record." The script can fix this by running a "clean"
|
||||
operation on the database, removing the orphaned entries.
|
||||
|
||||
2. An image is present in the outputs directory but there is no
|
||||
corresponding entry in the database. This can happen when the image
|
||||
is added manually to the outputs directory, or if a crash occurred
|
||||
after the image was generated but before the database was
|
||||
completely updated. The symptom is that the image is present in the
|
||||
outputs folder but doesn't appear in the InvokeAI gallery. This is
|
||||
called an "orphaned image file." The script can fix this problem by
|
||||
running an "archive" operation in which orphaned files are moved
|
||||
into a directory named `outputs/images-archive`. If you wish, you
|
||||
can then run `invokeai-image-import` to reimport these images back
|
||||
into the database.
|
||||
|
||||
3. The thumbnail for an image is missing, again causing a black
|
||||
gallery thumbnail. This is fixed by running the "thumbnaiils"
|
||||
operation, which simply regenerates and re-registers the missing
|
||||
thumbnail.
|
||||
|
||||
You can find and fix all three of these problems in a single go by
|
||||
executing this command:
|
||||
|
||||
```bash
|
||||
invokeai-db-maintenance --operation all
|
||||
```
|
||||
|
||||
Or you can run just the clean and thumbnail operations like this:
|
||||
|
||||
```bash
|
||||
invokeai-db-maintenance -operation clean, thumbnail
|
||||
```
|
||||
|
||||
If called without any arguments, the script will ask you which
|
||||
operations you wish to perform.
|
||||
|
||||
## **invokeai-migrate3**
|
||||
|
||||
This script will migrate settings and models (but not images!) from an
|
||||
InvokeAI v2.3 root folder to an InvokeAI 3.X folder. Call it with the
|
||||
source and destination root folders like this:
|
||||
|
||||
```bash
|
||||
invokeai-migrate3 --from ~/invokeai-2.3 --to invokeai-3.1.1
|
||||
```
|
||||
|
||||
Both directories must previously have been properly created and
|
||||
initialized by `invokeai-configure`. If you wish to migrate the images
|
||||
contained in the older root as well, you can use the
|
||||
`invokeai-image-migrate` script described earlier.
|
||||
|
||||
---
|
||||
|
||||
Copyright (c) 2023, Lincoln Stein and the InvokeAI Development Team
|
||||
@@ -51,6 +51,9 @@ Prevent InvokeAI from displaying unwanted racy images.
|
||||
### * [Controlling Logging](LOGGING.md)
|
||||
Control how InvokeAI logs status messages.
|
||||
|
||||
### * [Command-line Utilities](UTILITIES.md)
|
||||
A list of the command-line utilities available with InvokeAI.
|
||||
|
||||
<!-- OUT OF DATE
|
||||
### * [Miscellaneous](OTHER.md)
|
||||
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
||||
|
||||
@@ -147,6 +147,7 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
### InvokeAI Configuration
|
||||
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
|
||||
- [Database Maintenance and other Command Line Utilities](features/UTILITIES.md)
|
||||
|
||||
## :octicons-log-16: Important Changes Since Version 2.3
|
||||
|
||||
|
||||
@@ -256,6 +256,10 @@ manager, please follow these steps:
|
||||
*highly recommended** if your virtual environment is located outside of
|
||||
your runtime directory.
|
||||
|
||||
!!! tip
|
||||
|
||||
On linux, it is recommended to run invokeai with the following env var: `MALLOC_MMAP_THRESHOLD_=1048576`. For example: `MALLOC_MMAP_THRESHOLD_=1048576 invokeai --web`. This helps to prevent memory fragmentation that can lead to memory accumulation over time. This env var is set automatically when running via `invoke.sh`.
|
||||
|
||||
10. Render away!
|
||||
|
||||
Browse the [features](../features/index.md) section to learn about all the
|
||||
@@ -296,8 +300,18 @@ code for InvokeAI. For this to work, you will need to install the
|
||||
on your system, please see the [Git Installation
|
||||
Guide](https://github.com/git-guides/install-git)
|
||||
|
||||
You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md).
|
||||
|
||||
If you have a "normal" installation, you should create a totally separate virtual environment for the git-based installation, else the two may interfere.
|
||||
|
||||
> **Why do I need the frontend toolchain**?
|
||||
>
|
||||
> The InvokeAI project uses trunk-based development. That means our `main` branch is the development branch, and releases are tags on that branch. Because development is very active, we don't keep an updated build of the UI in `main` - we only build it for production releases.
|
||||
>
|
||||
> That means that between releases, to have a functioning application when running directly from the repo, you will need to run the UI in dev mode or build it regularly (any time the UI code changes).
|
||||
|
||||
1. Create a fork of the InvokeAI repository through the GitHub UI or [this link](https://github.com/invoke-ai/InvokeAI/fork)
|
||||
1. From the command line, run this command:
|
||||
2. From the command line, run this command:
|
||||
```bash
|
||||
git clone https://github.com/<your_github_username>/InvokeAI.git
|
||||
```
|
||||
@@ -305,10 +319,10 @@ Guide](https://github.com/git-guides/install-git)
|
||||
This will create a directory named `InvokeAI` and populate it with the
|
||||
full source code from your fork of the InvokeAI repository.
|
||||
|
||||
2. Activate the InvokeAI virtual environment as per step (4) of the manual
|
||||
3. Activate the InvokeAI virtual environment as per step (4) of the manual
|
||||
installation protocol (important!)
|
||||
|
||||
3. Enter the InvokeAI repository directory and run one of these
|
||||
4. Enter the InvokeAI repository directory and run one of these
|
||||
commands, based on your GPU:
|
||||
|
||||
=== "CUDA (NVidia)"
|
||||
@@ -334,11 +348,15 @@ installation protocol (important!)
|
||||
Be sure to pass `-e` (for an editable install) and don't forget the
|
||||
dot ("."). It is part of the command.
|
||||
|
||||
You can now run `invokeai` and its related commands. The code will be
|
||||
5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md) and do a production build of the UI as described.
|
||||
|
||||
6. You can now run `invokeai` and its related commands. The code will be
|
||||
read from the repository, so that you can edit the .py source files
|
||||
and watch the code's behavior change.
|
||||
|
||||
4. If you wish to contribute to the InvokeAI project, you are
|
||||
When you pull in new changes to the repo, be sure to re-build the UI.
|
||||
|
||||
7. If you wish to contribute to the InvokeAI project, you are
|
||||
encouraged to establish a GitHub account and "fork"
|
||||
https://github.com/invoke-ai/InvokeAI into your own copy of the
|
||||
repository. You can then use GitHub functions to create and submit
|
||||
|
||||
@@ -123,11 +123,20 @@ installation. Examples:
|
||||
# (list all controlnet models)
|
||||
invokeai-model-install --list controlnet
|
||||
|
||||
# (install the model at the indicated URL)
|
||||
# (install the diffusers model using its hugging face repo_id)
|
||||
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0
|
||||
|
||||
# (install a diffusers model that lives in a subfolder)
|
||||
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0:vae
|
||||
|
||||
# (install the checkpoint model at the indicated URL)
|
||||
invokeai-model-install --add https://civitai.com/api/download/models/128713
|
||||
|
||||
# (delete the named model)
|
||||
invokeai-model-install --delete sd-1/main/analog-diffusion
|
||||
# (delete the named model if its name is unique)
|
||||
invokeai-model-install --delete analog-diffusion
|
||||
|
||||
# (delete the named model using its fully qualified name)
|
||||
invokeai-model-install --delete sd-1/main/test_model
|
||||
```
|
||||
|
||||
### Installation via the Web GUI
|
||||
@@ -141,6 +150,24 @@ left-hand panel) and navigate to *Import Models*
|
||||
wish to install. You may use a URL, HuggingFace repo id, or a path on
|
||||
your local disk.
|
||||
|
||||
There is special scanning for CivitAI URLs which lets
|
||||
you cut-and-paste either the URL for a CivitAI model page
|
||||
(e.g. https://civitai.com/models/12345), or the direct download link
|
||||
for a model (e.g. https://civitai.com/api/download/models/12345).
|
||||
|
||||
If the desired model is a HuggingFace diffusers model that is located
|
||||
in a subfolder of the repository (e.g. vae), then append the subfolder
|
||||
to the end of the repo_id like this:
|
||||
|
||||
```
|
||||
# a VAE model located in subfolder "vae"
|
||||
stabilityai/stable-diffusion-xl-base-1.0:vae
|
||||
|
||||
# version 2 of the model located in subfolder "v2"
|
||||
monster-labs/control_v1p_sd15_qrcode_monster:v2
|
||||
|
||||
```
|
||||
|
||||
3. Alternatively, the *Scan for Models* button allows you to paste in
|
||||
the path to a folder somewhere on your machine. It will be scanned for
|
||||
importable models and prompt you to add the ones of your choice.
|
||||
@@ -171,3 +198,16 @@ subfolders and organize them as you wish.
|
||||
|
||||
The location of the autoimport directories are controlled by settings
|
||||
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
|
||||
|
||||
### Installing models that live in HuggingFace subfolders
|
||||
|
||||
On rare occasions you may need to install a diffusers-style model that
|
||||
lives in a subfolder of a HuggingFace repo id. In this event, simply
|
||||
add ":_subfolder-name_" to the end of the repo id. For example, if the
|
||||
repo id is "monster-labs/control_v1p_sd15_qrcode_monster" and the model
|
||||
you wish to fetch lives in a subfolder named "v2", then the repo id to
|
||||
pass to the various model installers should be
|
||||
|
||||
```
|
||||
monster-labs/control_v1p_sd15_qrcode_monster:v2
|
||||
```
|
||||
|
||||
@@ -4,12 +4,12 @@ The workflow editor is a blank canvas allowing for the use of individual functio
|
||||
|
||||
If you're not familiar with Diffusion, take a look at our [Diffusion Overview.](../help/diffusion.md) Understanding how diffusion works will enable you to more easily use the Workflow Editor and build workflows to suit your needs.
|
||||
|
||||
## UI Features
|
||||
## Features
|
||||
|
||||
### Linear View
|
||||
The Workflow Editor allows you to create a UI for your workflow, to make it easier to iterate on your generations.
|
||||
|
||||
To add an input to the Linear UI, right click on the input and select "Add to Linear View".
|
||||
To add an input to the Linear UI, right click on the input label and select "Add to Linear View".
|
||||
|
||||
The Linear UI View will also be part of the saved workflow, allowing you share workflows and enable other to use them, regardless of complexity.
|
||||
|
||||
@@ -25,6 +25,10 @@ Any node or input field can be renamed in the workflow editor. If the input fiel
|
||||
* Backspace/Delete to delete a node
|
||||
* Shift+Click to drag and select multiple nodes
|
||||
|
||||
### Node Caching
|
||||
|
||||
Nodes have a "Use Cache" option in their footer. This allows for performance improvements by using the previously cached values during the workflow processing.
|
||||
|
||||
|
||||
## Important Concepts
|
||||
|
||||
|
||||
@@ -8,19 +8,21 @@ To download a node, simply download the `.py` node file from the link and add it
|
||||
|
||||
To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
|
||||
|
||||
## Community Nodes
|
||||
--------------------------------
|
||||
|
||||
### FaceTools
|
||||
--------------------------------
|
||||
### Make 3D
|
||||
|
||||
**Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
|
||||
**Description:** Create compelling 3D stereo images from 2D originals.
|
||||
|
||||
**Node Link:** https://github.com/ymgenesis/FaceTools/
|
||||
**Node Link:** [https://gitlab.com/srcrr/shift3d/-/raw/main/make3d.py](https://gitlab.com/srcrr/shift3d)
|
||||
|
||||
**FaceMask Output Examples**
|
||||
**Example Node Graph:** https://gitlab.com/srcrr/shift3d/-/raw/main/example-workflow.json?ref_type=heads&inline=false
|
||||
|
||||

|
||||

|
||||

|
||||
**Output Examples**
|
||||
|
||||
{: style="height:512px;width:512px"}
|
||||
{: style="height:512px;width:512px"}
|
||||
|
||||
--------------------------------
|
||||
### Ideal Size
|
||||
@@ -43,6 +45,52 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/image-picker-node
|
||||
|
||||
--------------------------------
|
||||
### Thresholding
|
||||
|
||||
**Description:** This node generates masks for highlights, midtones, and shadows given an input image. You can optionally specify a blur for the lookup table used in making those masks from the source image.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/thresholding-node
|
||||
|
||||
**Examples**
|
||||
|
||||
Input:
|
||||
|
||||
{: style="height:512px;width:512px"}
|
||||
|
||||
Highlights/Midtones/Shadows:
|
||||
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/727021c1-36ff-4ec8-90c8-105e00de986d" style="width: 30%" />
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0b721bfc-f051-404e-b905-2f16b824ddfe" style="width: 30%" />
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/04c1297f-1c88-42b6-a7df-dd090b976286" style="width: 30%" />
|
||||
|
||||
Highlights/Midtones/Shadows (with LUT blur enabled):
|
||||
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/19aa718a-70c1-4668-8169-d68f4bd13771" style="width: 30%" />
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0a440e43-697f-4d17-82ee-f287467df0a5" style="width: 30%" />
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0701fd0f-2ca7-4fe2-8613-2b52547bafce" style="width: 30%" />
|
||||
|
||||
--------------------------------
|
||||
### Halftone
|
||||
|
||||
**Description**: Halftone converts the source image to grayscale and then performs halftoning. CMYK Halftone converts the image to CMYK and applies a per-channel halftoning to make the source image look like a magazine or newspaper. For both nodes, you can specify angles and halftone dot spacing.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/halftone-node
|
||||
|
||||
**Example**
|
||||
|
||||
Input:
|
||||
|
||||
{: style="height:512px;width:512px"}
|
||||
|
||||
Halftone Output:
|
||||
|
||||
{: style="height:512px;width:512px"}
|
||||
|
||||
CMYK Halftone Output:
|
||||
|
||||
{: style="height:512px;width:512px"}
|
||||
|
||||
--------------------------------
|
||||
### Retroize
|
||||
|
||||
@@ -77,7 +125,7 @@ Generated Prompt: An enchanted weapon will be usable by any character regardless
|
||||
**Example Node Graph:** https://github.com/helix4u/load_video_frame/blob/main/Example_Workflow.json
|
||||
|
||||
**Output Example:**
|
||||
=======
|
||||
|
||||

|
||||
[Full mp4 of Example Output test.mp4](https://github.com/helix4u/load_video_frame/blob/main/test.mp4)
|
||||
|
||||
@@ -121,18 +169,6 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Enhance Image (simple adjustments)
|
||||
|
||||
**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
|
||||
|
||||
Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/image-enhance-node
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Generative Grammar-Based Prompt Nodes
|
||||
|
||||
@@ -153,16 +189,28 @@ This includes 3 Nodes:
|
||||
|
||||
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
|
||||
|
||||
This includes 4 Nodes:
|
||||
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
|
||||
This includes 15 Nodes:
|
||||
|
||||
- *Adjust Image Hue Plus* - Rotate the hue of an image in one of several different color spaces.
|
||||
- *Blend Latents/Noise (Masked)* - Use a mask to blend part of one latents tensor [including Noise outputs] into another. Can be used to "renoise" sections during a multi-stage [masked] denoising process.
|
||||
- *Enhance Image* - Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
|
||||
- *Equivalent Achromatic Lightness* - Calculates image lightness accounting for Helmholtz-Kohlrausch effect based on a method described by High, Green, and Nussbaum (2023).
|
||||
- *Text to Mask (Clipseg)* - Input a prompt and an image to generate a mask representing areas of the image matched by the prompt.
|
||||
- *Text to Mask Advanced (Clipseg)* - Output up to four prompt masks combined with logical "and", logical "or", or as separate channels of an RGBA image.
|
||||
- *Image Layer Blend* - Perform a layered blend of two images using alpha compositing. Opacity of top layer is selectable, with optional mask and several different blend modes/color spaces.
|
||||
- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
|
||||
- *Image Dilate or Erode* - Dilate or expand a mask (or any image!). This is equivalent to an expand/contract operation.
|
||||
- *Image Value Thresholds* - Clip an image to pure black/white beyond specified thresholds.
|
||||
- *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
|
||||
- *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
|
||||
- *Rotate/Flip Image* - Rotate an image in degrees clockwise/counterclockwise about its center, optionally resizing the image boundaries to fit, or flipping it about the vertical and/or horizontal axes.
|
||||
- *Shadows/Highlights/Midtones* - Extract three masks (with adjustable hard or soft thresholds) representing shadows, midtones, and highlights regions of an image.
|
||||
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/composition-nodes
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
**Nodes and Output Examples:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Size Stepper Nodes
|
||||
@@ -196,6 +244,70 @@ Results after using the depth controlnet
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Prompt Tools
|
||||
|
||||
**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These where written to accompany the PromptsFromFile node and other prompt generation nodes.
|
||||
|
||||
1. PromptJoin - Joins to prompts into one.
|
||||
2. PromptReplace - performs a search and replace on a prompt. With the option of using regex.
|
||||
3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative.
|
||||
4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
|
||||
5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file.
|
||||
6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node.
|
||||
7. PromptJoinThree - Joins 3 prompt together.
|
||||
8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel.
|
||||
9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node.
|
||||
|
||||
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
|
||||
|
||||
--------------------------------
|
||||
|
||||
### XY Image to Grid and Images to Grids nodes
|
||||
|
||||
**Description:** Image to grid nodes and supporting tools.
|
||||
|
||||
1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then mutilple grids will be created until it runs out of images.
|
||||
2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporoting nodes. See example node setups for more details.
|
||||
|
||||
|
||||
See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Image to Character Art Image Node's
|
||||
|
||||
**Description:** Group of nodes to convert an input image into ascii/unicode art Image
|
||||
|
||||
**Node Link:** https://github.com/mickr777/imagetoasciiimage
|
||||
|
||||
**Output Examples**
|
||||
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/115216705/8e061fcc-9a2c-4fa9-bcc7-c0f7b01e9056" width="300" />
|
||||
<img src="https://github.com/mickr777/imagetoasciiimage/assets/115216705/3c4990eb-2f42-46b9-90f9-0088b939dc6a" width="300" /></br>
|
||||
<img src="https://github.com/mickr777/imagetoasciiimage/assets/115216705/fee7f800-a4a8-41e2-a66b-c66e4343307e" width="300" />
|
||||
<img src="https://github.com/mickr777/imagetoasciiimage/assets/115216705/1d9c1003-a45f-45c2-aac7-46470bb89330" width="300" />
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Grid to Gif
|
||||
|
||||
**Description:** One node that turns a grid image into an image colletion, one node that turns an image collection into a gif
|
||||
|
||||
**Node Link:** https://github.com/mildmisery/invokeai-GridToGifNode/blob/main/GridToGif.py
|
||||
|
||||
**Example Node Graph:** https://github.com/mildmisery/invokeai-GridToGifNode/blob/main/Grid%20to%20Gif%20Example%20Workflow.json
|
||||
|
||||
**Output Examples**
|
||||
|
||||
<img src="https://raw.githubusercontent.com/mildmisery/invokeai-GridToGifNode/main/input.png" width="300" />
|
||||
<img src="https://raw.githubusercontent.com/mildmisery/invokeai-GridToGifNode/main/output.gif" width="300" />
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Example Node Template
|
||||
|
||||
**Description:** This node allows you to do super cool things with InvokeAI.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# List of Default Nodes
|
||||
|
||||
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
|
||||
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
|
||||
|
||||
| Node <img width=160 align="right"> | Function |
|
||||
|: ---------------------------------- | :--------------------------------------------------------------------------------------|
|
||||
@@ -17,11 +17,12 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|
||||
|Conditioning Primitive | A conditioning tensor primitive value|
|
||||
|Content Shuffle Processor | Applies content shuffle processing to image|
|
||||
|ControlNet | Collects ControlNet info to pass to other nodes|
|
||||
|OpenCV Inpaint | Simple inpaint using opencv.|
|
||||
|Denoise Latents | Denoises noisy latents to decodable images|
|
||||
|Divide Integers | Divides two numbers|
|
||||
|Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator|
|
||||
|Upscale (RealESRGAN) | Upscales an image using RealESRGAN.|
|
||||
|[FaceMask](./detailedNodes/faceTools.md#facemask) | Generates masks for faces in an image to use with Inpainting|
|
||||
|[FaceIdentifier](./detailedNodes/faceTools.md#faceidentifier) | Identifies and labels faces in an image|
|
||||
|[FaceOff](./detailedNodes/faceTools.md#faceoff) | Creates a new image that is a scaled bounding box with a mask on the face for Inpainting|
|
||||
|Float Math | Perform basic math operations on two floats|
|
||||
|Float Primitive Collection | A collection of float primitive values|
|
||||
|Float Primitive | A float primitive value|
|
||||
@@ -76,6 +77,7 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|
||||
|ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers.|
|
||||
|ONNX Text to Latents | Generates latents from conditionings.|
|
||||
|ONNX Model Loader | Loads a main model, outputting its submodels.|
|
||||
|OpenCV Inpaint | Simple inpaint using opencv.|
|
||||
|Openpose Processor | Applies Openpose processing to image|
|
||||
|PIDI Processor | Applies PIDI processing to image|
|
||||
|Prompts from File | Loads prompts from a text file|
|
||||
@@ -97,5 +99,6 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|
||||
|String Primitive | A string primitive value|
|
||||
|Subtract Integers | Subtracts two numbers|
|
||||
|Tile Resample Processor | Tile resampler processor|
|
||||
|Upscale (RealESRGAN) | Upscales an image using RealESRGAN.|
|
||||
|VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput|
|
||||
|Zoe (Depth) Processor | Applies Zoe depth processing to image|
|
||||
154
docs/nodes/detailedNodes/faceTools.md
Normal file
154
docs/nodes/detailedNodes/faceTools.md
Normal file
@@ -0,0 +1,154 @@
|
||||
# Face Nodes
|
||||
|
||||
## FaceOff
|
||||
|
||||
FaceOff mimics a user finding a face in an image and resizing the bounding box
|
||||
around the head in Canvas.
|
||||
|
||||
Enter a face ID (found with FaceIdentifier) to choose which face to mask.
|
||||
|
||||
Just as you would add more context inside the bounding box by making it larger
|
||||
in Canvas, the node gives you a padding input (in pixels) which will
|
||||
simultaneously add more context, and increase the resolution of the bounding box
|
||||
so the face remains the same size inside it.
|
||||
|
||||
The "Minimum Confidence" input defaults to 0.5 (50%), and represents a pass/fail
|
||||
threshold a detected face must reach for it to be processed. Lowering this value
|
||||
may help if detection is failing. If the detected masks are imperfect and stray
|
||||
too far outside/inside of faces, the node gives you X & Y offsets to shrink/grow
|
||||
the masks by a multiplier.
|
||||
|
||||
FaceOff will output the face in a bounded image, taking the face off of the
|
||||
original image for input into any node that accepts image inputs. The node also
|
||||
outputs a face mask with the dimensions of the bounded image. The X & Y outputs
|
||||
are for connecting to the X & Y inputs of the Paste Image node, which will place
|
||||
the bounded image back on the original image using these coordinates.
|
||||
|
||||
###### Inputs/Outputs
|
||||
|
||||
| Input | Description |
|
||||
| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Image | Image for face detection |
|
||||
| Face ID | The face ID to process, numbered from 0. Multiple faces not supported. Find a face's ID with FaceIdentifier node. |
|
||||
| Minimum Confidence | Minimum confidence for face detection (lower if detection is failing) |
|
||||
| X Offset | X-axis offset of the mask |
|
||||
| Y Offset | Y-axis offset of the mask |
|
||||
| Padding | All-axis padding around the mask in pixels |
|
||||
| Chunk | Chunk (or divide) the image into sections to greatly improve face detection success. Defaults to off, but will activate if no faces are detected normally. Activate to chunk by default. |
|
||||
|
||||
| Output | Description |
|
||||
| ------------- | ------------------------------------------------ |
|
||||
| Bounded Image | Original image bound, cropped, and resized |
|
||||
| Width | The width of the bounded image in pixels |
|
||||
| Height | The height of the bounded image in pixels |
|
||||
| Mask | The output mask |
|
||||
| X | The x coordinate of the bounding box's left side |
|
||||
| Y | The y coordinate of the bounding box's top side |
|
||||
|
||||
## FaceMask
|
||||
|
||||
FaceMask mimics a user drawing masks on faces in an image in Canvas.
|
||||
|
||||
The "Face IDs" input allows the user to select specific faces to be masked.
|
||||
Leave empty to detect and mask all faces, or a comma-separated list for a
|
||||
specific combination of faces (ex: `1,2,4`). A single integer will detect and
|
||||
mask that specific face. Find face IDs with the FaceIdentifier node.
|
||||
|
||||
The "Minimum Confidence" input defaults to 0.5 (50%), and represents a pass/fail
|
||||
threshold a detected face must reach for it to be processed. Lowering this value
|
||||
may help if detection is failing.
|
||||
|
||||
If the detected masks are imperfect and stray too far outside/inside of faces,
|
||||
the node gives you X & Y offsets to shrink/grow the masks by a multiplier. All
|
||||
masks shrink/grow together by the X & Y offset values.
|
||||
|
||||
By default, masks are created to change faces. When masks are inverted, they
|
||||
change surrounding areas, protecting faces.
|
||||
|
||||
###### Inputs/Outputs
|
||||
|
||||
| Input | Description |
|
||||
| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Image | Image for face detection |
|
||||
| Face IDs | Comma-separated list of face ids to mask eg '0,2,7'. Numbered from 0. Leave empty to mask all. Find face IDs with FaceIdentifier node. |
|
||||
| Minimum Confidence | Minimum confidence for face detection (lower if detection is failing) |
|
||||
| X Offset | X-axis offset of the mask |
|
||||
| Y Offset | Y-axis offset of the mask |
|
||||
| Chunk | Chunk (or divide) the image into sections to greatly improve face detection success. Defaults to off, but will activate if no faces are detected normally. Activate to chunk by default. |
|
||||
| Invert Mask | Toggle to invert the face mask |
|
||||
|
||||
| Output | Description |
|
||||
| ------ | --------------------------------- |
|
||||
| Image | The original image |
|
||||
| Width | The width of the image in pixels |
|
||||
| Height | The height of the image in pixels |
|
||||
| Mask | The output face mask |
|
||||
|
||||
## FaceIdentifier
|
||||
|
||||
FaceIdentifier outputs an image with detected face IDs printed in white numbers
|
||||
onto each face.
|
||||
|
||||
Face IDs can then be used in FaceMask and FaceOff to selectively mask all, a
|
||||
specific combination, or single faces.
|
||||
|
||||
The FaceIdentifier output image is generated for user reference, and isn't meant
|
||||
to be passed on to other image-processing nodes.
|
||||
|
||||
The "Minimum Confidence" input defaults to 0.5 (50%), and represents a pass/fail
|
||||
threshold a detected face must reach for it to be processed. Lowering this value
|
||||
may help if detection is failing. If an image is changed in the slightest, run
|
||||
it through FaceIdentifier again to get updated FaceIDs.
|
||||
|
||||
###### Inputs/Outputs
|
||||
|
||||
| Input | Description |
|
||||
| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Image | Image for face detection |
|
||||
| Minimum Confidence | Minimum confidence for face detection (lower if detection is failing) |
|
||||
| Chunk | Chunk (or divide) the image into sections to greatly improve face detection success. Defaults to off, but will activate if no faces are detected normally. Activate to chunk by default. |
|
||||
|
||||
| Output | Description |
|
||||
| ------ | ------------------------------------------------------------------------------------------------ |
|
||||
| Image | The original image with small face ID numbers printed in white onto each face for user reference |
|
||||
| Width | The width of the original image in pixels |
|
||||
| Height | The height of the original image in pixels |
|
||||
|
||||
## Tips
|
||||
|
||||
- If not all target faces are being detected, activate Chunk to bypass full
|
||||
image face detection and greatly improve detection success.
|
||||
- Final results will vary between full-image detection and chunking for faces
|
||||
that are detectable by both due to the nature of the process. Try either to
|
||||
your taste.
|
||||
- Be sure Minimum Confidence is set the same when using FaceIdentifier with
|
||||
FaceOff/FaceMask.
|
||||
- For FaceOff, use the color correction node before faceplace to correct edges
|
||||
being noticeable in the final image (see example screenshot).
|
||||
- Non-inpainting models may struggle to paint/generate correctly around faces.
|
||||
- If your face won't change the way you want it to no matter what you change,
|
||||
consider that the change you're trying to make is too much at that resolution.
|
||||
For example, if an image is only 512x768 total, the face might only be 128x128
|
||||
or 256x256, much smaller than the 512x512 your SD1.5 model was probably
|
||||
trained on. Try increasing the resolution of the image by upscaling or
|
||||
resizing, add padding to increase the bounding box's resolution, or use an
|
||||
image where the face takes up more pixels.
|
||||
- If the resulting face seems out of place pasted back on the original image
|
||||
(ie. too large, not proportional), add more padding on the FaceOff node to
|
||||
give inpainting more context. Context and good prompting are important to
|
||||
keeping things proportional.
|
||||
- If you find the mask is too big/small and going too far outside/inside the
|
||||
area you want to affect, adjust the x & y offsets to shrink/grow the mask area
|
||||
- Use a higher denoise start value to resemble aspects of the original face or
|
||||
surroundings. Denoise start = 0 & denoise end = 1 will make something new,
|
||||
while denoise start = 0.50 & denoise end = 1 will be 50% old and 50% new.
|
||||
- mediapipe isn't good at detecting faces with lots of face paint, hair covering
|
||||
the face, etc. Anything that obstructs the face will likely result in no faces
|
||||
being detected.
|
||||
- If you find your face isn't being detected, try lowering the minimum
|
||||
confidence value from 0.5. This could result in false positives, however
|
||||
(random areas being detected as faces and masked).
|
||||
- After altering an image and wanting to process a different face in the newly
|
||||
altered image, run the altered image through FaceIdentifier again to see the
|
||||
new Face IDs. MediaPipe will most likely detect faces in a different order
|
||||
after an image has been changed in the slightest.
|
||||
@@ -9,5 +9,6 @@ If you're interested in finding more workflows, checkout the [#share-your-workfl
|
||||
* [SD1.5 / SD2 Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Text_to_Image.json)
|
||||
* [SDXL Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_Text_to_Image.json)
|
||||
* [SDXL (with Refiner) Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_Text_to_Image.json)
|
||||
* [Tiled Upscaling with ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/ESRGAN_img2img_upscale w_Canny_ControlNet.json)ß
|
||||
|
||||
* [Tiled Upscaling with ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/ESRGAN_img2img_upscale w_Canny_ControlNet.json)
|
||||
* [FaceMask](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceMask.json)
|
||||
* [FaceOff with 2x Face Scaling](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceOff_FaceScale2x.json)
|
||||
|
||||
1041
docs/workflows/FaceMask.json
Normal file
1041
docs/workflows/FaceMask.json
Normal file
File diff suppressed because it is too large
Load Diff
1395
docs/workflows/FaceOff_FaceScale2x.json
Normal file
1395
docs/workflows/FaceOff_FaceScale2x.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -332,6 +332,7 @@ class InvokeAiInstance:
|
||||
Configure the InvokeAI runtime directory
|
||||
"""
|
||||
|
||||
auto_install = False
|
||||
# set sys.argv to a consistent state
|
||||
new_argv = [sys.argv[0]]
|
||||
for i in range(1, len(sys.argv)):
|
||||
@@ -340,13 +341,17 @@ class InvokeAiInstance:
|
||||
new_argv.append(el)
|
||||
new_argv.append(sys.argv[i + 1])
|
||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||
new_argv.append(el)
|
||||
auto_install = True
|
||||
sys.argv = new_argv
|
||||
|
||||
import messages
|
||||
import requests # to catch download exceptions
|
||||
from messages import introduction
|
||||
|
||||
introduction()
|
||||
auto_install = auto_install or messages.user_wants_auto_configuration()
|
||||
if auto_install:
|
||||
sys.argv.append("--yes")
|
||||
else:
|
||||
messages.introduction()
|
||||
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit import HTML, prompt
|
||||
from prompt_toolkit.completion import PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
@@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||
dest_confirmed = Confirm.ask(
|
||||
":stop_sign: Are you sure you want to (re)install in this location?",
|
||||
":stop_sign: (re)install in this location?",
|
||||
default=False,
|
||||
)
|
||||
else:
|
||||
print(f"InvokeAI will be installed in {dest}")
|
||||
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
|
||||
dest_confirmed = Confirm.ask("Use this location?", default=True)
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
|
||||
|
||||
def user_wants_auto_configuration() -> bool:
|
||||
"""Prompt the user to choose between manual and auto configuration."""
|
||||
console.rule("InvokeAI Configuration Section")
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
|
||||
"",
|
||||
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
|
||||
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
|
||||
"",
|
||||
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
|
||||
]
|
||||
),
|
||||
),
|
||||
box=box.MINIMAL,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
choice = (
|
||||
prompt(
|
||||
HTML("Choose <b><a></b>utomatic or <b><m></b>anual configuration [a/m] (a): "),
|
||||
validator=Validator.from_callable(
|
||||
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
|
||||
),
|
||||
)
|
||||
or "a"
|
||||
)
|
||||
return choice.lower().startswith("a")
|
||||
|
||||
|
||||
def dest_path(dest=None) -> Path:
|
||||
"""
|
||||
Prompt the user for the destination path and create the path
|
||||
|
||||
@@ -17,9 +17,10 @@ echo 6. Change InvokeAI startup options
|
||||
echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
|
||||
echo 8. Open the developer console
|
||||
echo 9. Update InvokeAI
|
||||
echo 10. Command-line help
|
||||
echo 10. Run the InvokeAI image database maintenance script
|
||||
echo 11. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [1] "
|
||||
set /P choice="Please enter 1-11, Q: [1] "
|
||||
if not defined choice set choice=1
|
||||
IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
@@ -58,8 +59,11 @@ IF /I "%choice%" == "1" (
|
||||
echo Running invokeai-update...
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
echo Running the db maintenance script...
|
||||
python .venv\Scripts\invokeai-db-maintenance.exe
|
||||
) ELSE IF /I "%choice%" == "11" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
python .venv\Scripts\invokeai-web.exe --help %*
|
||||
pause
|
||||
exit /b
|
||||
) ELSE IF /I "%choice%" == "q" (
|
||||
|
||||
@@ -46,6 +46,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
fi
|
||||
|
||||
# Avoid glibc memory fragmentation. See invokeai/backend/model_management/README.md for details.
|
||||
export MALLOC_MMAP_THRESHOLD_=1048576
|
||||
|
||||
# Primary function for the case statement to determine user input
|
||||
do_choice() {
|
||||
case $1 in
|
||||
@@ -97,13 +100,13 @@ do_choice() {
|
||||
;;
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
printf "Running the db maintenance script\n"
|
||||
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
"HELP 1")
|
||||
11)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai --help
|
||||
invokeai-web --help
|
||||
;;
|
||||
*)
|
||||
clear
|
||||
@@ -125,7 +128,10 @@ do_dialog() {
|
||||
6 "Change InvokeAI startup options"
|
||||
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
|
||||
8 "Open the developer console"
|
||||
9 "Update InvokeAI")
|
||||
9 "Update InvokeAI"
|
||||
10 "Run the InvokeAI image database maintenance script"
|
||||
11 "Command-line help"
|
||||
)
|
||||
|
||||
choice=$(dialog --clear \
|
||||
--backtitle "\Zb\Zu\Z3InvokeAI" \
|
||||
@@ -157,9 +163,10 @@ do_line_input() {
|
||||
printf "7: Re-run the configure script to fix a broken install\n"
|
||||
printf "8: Open the developer console\n"
|
||||
printf "9: Update InvokeAI\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "10: Run the InvokeAI image database maintenance script\n"
|
||||
printf "11: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
read -p "Please enter 1-11, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
clear
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||
@@ -9,12 +10,16 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.download_manager import DownloadQueueService
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
@@ -22,9 +27,12 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.model_install_service import ModelInstallService
|
||||
from ..services.model_loader_service import ModelLoadService
|
||||
from ..services.model_record_service import ModelRecordServiceBase
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.thread import lock
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@@ -44,7 +52,7 @@ def check_internet() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class ApiDependencies:
|
||||
@@ -63,22 +71,32 @@ class ApiDependencies:
|
||||
output_folder = config.output_path
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
if config.use_memory_db:
|
||||
db_location = ":memory:"
|
||||
else:
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
|
||||
logger.info(f"Using database at {db_location}")
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
|
||||
if config.log_sql:
|
||||
db_conn.set_trace_callback(print)
|
||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -112,26 +130,45 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
download_queue = DownloadQueueService(event_bus=events)
|
||||
model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=lock)
|
||||
model_loader = ModelLoadService(config, model_record_store)
|
||||
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
download_queue=download_queue,
|
||||
model_record_store=model_record_store,
|
||||
model_loader=model_loader,
|
||||
model_installer=model_installer,
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
try:
|
||||
lock.acquire()
|
||||
db_conn.execute("VACUUM;")
|
||||
db_conn.commit()
|
||||
logger.info("Cleaned database")
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
if ApiDependencies.invoker:
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
@@ -103,3 +104,43 @@ async def set_log_level(
|
||||
"""Sets the log verbosity level"""
|
||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
||||
|
||||
@app_router.delete(
|
||||
"/invocation_cache",
|
||||
operation_id="clear_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def clear_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.clear()
|
||||
|
||||
|
||||
@app_router.put(
|
||||
"/invocation_cache/enable",
|
||||
operation_id="enable_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def enable_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.enable()
|
||||
|
||||
|
||||
@app_router.put(
|
||||
"/invocation_cache/disable",
|
||||
operation_id="disable_invocation_cache",
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
)
|
||||
async def disable_invocation_cache() -> None:
|
||||
"""Clears the invocation cache"""
|
||||
ApiDependencies.invoker.services.invocation_cache.disable()
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/invocation_cache/status",
|
||||
operation_id="get_invocation_cache_status",
|
||||
responses={200: {"model": InvocationCacheStatus}},
|
||||
)
|
||||
async def get_invocation_cache_status() -> InvocationCacheStatus:
|
||||
"""Clears the invocation cache"""
|
||||
return ApiDependencies.invoker.services.invocation_cache.get_status()
|
||||
|
||||
@@ -2,35 +2,60 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import List, Literal, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.download_manager import DownloadJobRemoteSource, DownloadJobStatus, UnknownJobIDException
|
||||
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management.models import (
|
||||
from invokeai.backend.model_manager import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
ModelConfigBase,
|
||||
ModelSearch,
|
||||
SchedulerPredictionType,
|
||||
UnknownModelException,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
|
||||
# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker
|
||||
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
|
||||
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()
|
||||
|
||||
# There are still numerous mypy errors here because it does not seem to like this
|
||||
# way of dynamically generating the typing hints below.
|
||||
InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
models: List[InvokeAIModelConfig]
|
||||
|
||||
|
||||
class ModelDownloadStatus(BaseModel):
|
||||
"""Return information about a background installation job."""
|
||||
|
||||
job_id: int
|
||||
source: str
|
||||
priority: int
|
||||
bytes: int
|
||||
total_bytes: int
|
||||
status: DownloadJobStatus
|
||||
|
||||
|
||||
class JobControlOperation(str, Enum):
|
||||
START = "Start"
|
||||
PAUSE = "Pause"
|
||||
CANCEL = "Cancel"
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@@ -42,19 +67,22 @@ async def list_models(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
models_raw.extend(
|
||||
[x.dict() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)]
|
||||
)
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models_raw = [x.dict() for x in record_store.search_by_name(model_type=model_type)]
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
"/i/{key}",
|
||||
operation_id="update_model",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
@@ -63,69 +91,36 @@ async def list_models(
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=UpdateModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def update_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
key: str = Path(description="Unique key of model"),
|
||||
info: InvokeAIModelConfig = Body(description="Model configuration"),
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
info_dict = info.dict()
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
model_install = ApiDependencies.invoker.services.model_installer
|
||||
try:
|
||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
# rename operation requested
|
||||
if info.model_name != model_name or info.base_model != base_model:
|
||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=info.model_name,
|
||||
new_base=info.base_model,
|
||||
)
|
||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||
# update information to support an update of attributes
|
||||
model_name = info.model_name
|
||||
base_model = info.base_model
|
||||
new_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
if new_info.get("path") != previous_info.get(
|
||||
"path"
|
||||
): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get("path")
|
||||
|
||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||
info_dict = info.dict()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
|
||||
)
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
new_config = record_store.update_model(key, config=info_dict)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
|
||||
try:
|
||||
# In the event that the model's name, type or base has changed, and the model itself
|
||||
# resides in the invokeai root models directory, then the next statement will move
|
||||
# the model file into its new canonical location.
|
||||
new_config = model_install.sync_model_path(new_config.key)
|
||||
model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict())
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
@@ -141,37 +136,55 @@ async def update_model(
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
response_model=ModelDownloadStatus,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
||||
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
|
||||
default=None,
|
||||
),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
priority: Optional[int] = Body(
|
||||
description="Which import jobs run first. Lower values run before higher ones.",
|
||||
default=10,
|
||||
),
|
||||
) -> ModelDownloadStatus:
|
||||
"""
|
||||
Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has a `job_id` property
|
||||
that can be used to control the download job.
|
||||
|
||||
The priority controls which import jobs run first. Lower values run before
|
||||
higher ones.
|
||||
|
||||
The prediction_type applies to SDv2 models only and can be one of
|
||||
"v_prediction", "epsilon", or "sample". Default if not provided is
|
||||
"v_prediction".
|
||||
|
||||
Listen on the event bus for a series of `model_event` events with an `id`
|
||||
matching the returned job id to get the progress, completion status, errors,
|
||||
and information on the model that was installed.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
result = installer.install_model(
|
||||
location,
|
||||
probe_override={"prediction_type": SchedulerPredictionType(prediction_type) if prediction_type else None},
|
||||
priority=priority,
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
logger.info(f"Successfully imported {location}, got {info}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
return ModelDownloadStatus(
|
||||
job_id=result.id,
|
||||
source=result.source,
|
||||
priority=result.priority,
|
||||
bytes=result.bytes,
|
||||
total_bytes=result.total_bytes,
|
||||
status=result.status,
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
@@ -188,29 +201,40 @@ async def import_model(
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
info: InvokeAIModelConfig = Body(description="Model configuration"),
|
||||
) -> InvokeAIModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type. Only local models can be added by path.
|
||||
This call will block until the model is installed.
|
||||
"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
path = info.path
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
record_store = ApiDependencies.invoker.services.model_record_store
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||
)
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
key = installer.install_path(path)
|
||||
logger.info(f"Created model {key} for {path}")
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# update with the provided info
|
||||
try:
|
||||
info_dict = info.dict()
|
||||
new_config = record_store.update_model(key, new_config=info_dict)
|
||||
return parse_obj_as(InvokeAIModelConfig, new_config.dict())
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
@@ -219,33 +243,34 @@ async def add_model(
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
"/i/{key}",
|
||||
operation_id="del_model",
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
delete_files: Optional[bool] = Query(description="Delete underlying files and directories as well.", default=False),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
if delete_files:
|
||||
installer.delete(key)
|
||||
else:
|
||||
installer.unregister(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except ModelNotFoundException as e:
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
"/convert/{key}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
@@ -253,33 +278,26 @@ async def delete_model(
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ConvertModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
key: str = Path(description="Unique key of model to convert from checkpoint/safetensors to diffusers format."),
|
||||
convert_dest_directory: Optional[str] = Query(
|
||||
default=None, description="Save the converted model to the designated directory"
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
convert_dest_directory=dest,
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
model_config = converter.convert_model(key, dest_directory=dest)
|
||||
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -298,11 +316,12 @@ async def convert_model(
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
"""Search for all models in a server-local path."""
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
return ModelSearch().search(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@@ -316,7 +335,10 @@ async def search_for_models(
|
||||
)
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
|
||||
@models_router.post(
|
||||
@@ -329,27 +351,32 @@ async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
"""
|
||||
Synchronize model in-memory data structures with disk.
|
||||
|
||||
Call after making changes to models.yaml, autoimport directories
|
||||
or models directory.
|
||||
"""
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
installer.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
"/merge",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
409: {"description": "An identical merged model is already installed"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
response_model=InvokeAIModelConfig,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
keys: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(
|
||||
@@ -359,29 +386,147 @@ async def merge_models(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
) -> InvokeAIModelConfig:
|
||||
"""Merge the indicated diffusers model."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store,
|
||||
)
|
||||
result: ModelConfigBase = converter.merge_models(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
response = parse_obj_as(InvokeAIModelConfig, result.dict())
|
||||
except DuplicateModelException as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except UnknownModelException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/jobs",
|
||||
operation_id="list_install_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[ModelDownloadStatus],
|
||||
)
|
||||
async def list_install_jobs() -> List[ModelDownloadStatus]:
|
||||
"""List active and pending model installation jobs."""
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
jobs = job_mgr.list_jobs()
|
||||
return [
|
||||
ModelDownloadStatus(
|
||||
job_id=x.id,
|
||||
source=x.source,
|
||||
priority=x.priority,
|
||||
bytes=x.bytes,
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs
|
||||
if isinstance(x, ModelInstallJob)
|
||||
]
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/control/{operation}/{job_id}",
|
||||
operation_id="control_download_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The job could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelDownloadStatus,
|
||||
)
|
||||
async def control_download_jobs(
|
||||
job_id: int = Path(description="Download/install job_id for start, pause and cancel operations"),
|
||||
operation: JobControlOperation = Path(description="The operation to perform on the job."),
|
||||
priority_delta: Optional[int] = Body(
|
||||
description="Change in job priority for priority operations only. Negative numbers increase priority.",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelDownloadStatus:
|
||||
"""Start, pause, cancel, or change the run priority of a running model install job."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
try:
|
||||
job = job_mgr.id_to_job(job_id)
|
||||
|
||||
if operation == JobControlOperation.START:
|
||||
job_mgr.start_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.PAUSE:
|
||||
job_mgr.pause_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.CANCEL:
|
||||
job_mgr.cancel_job(job_id)
|
||||
|
||||
else:
|
||||
raise ValueError("unknown operation {operation}")
|
||||
bytes = 0
|
||||
total_bytes = 0
|
||||
if isinstance(job, DownloadJobRemoteSource):
|
||||
bytes = job.bytes
|
||||
total_bytes = job.total_bytes
|
||||
|
||||
return ModelDownloadStatus(
|
||||
job_id=job_id,
|
||||
source=job.source,
|
||||
priority=job.priority,
|
||||
status=job.status,
|
||||
bytes=bytes,
|
||||
total_bytes=total_bytes,
|
||||
)
|
||||
except UnknownJobIDException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/cancel_all",
|
||||
operation_id="cancel_all_download_jobs",
|
||||
responses={
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def cancel_all_download_jobs():
|
||||
"""Cancel all model installation jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
logger.info("Cancelling all download jobs.")
|
||||
job_mgr.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/prune",
|
||||
operation_id="prune_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_jobs():
|
||||
"""Prune all completed and errored jobs."""
|
||||
mgr = ApiDependencies.invoker.services.download_queue
|
||||
mgr.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
247
invokeai/app/api/routers/session_queue.py
Normal file
247
invokeai/app/api/routers/session_queue.py
Normal file
@@ -0,0 +1,247 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
PruneResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
from ...services.graph import Graph
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
|
||||
|
||||
class SessionQueueAndProcessorStatus(BaseModel):
|
||||
"""The overall status of session queue and processor"""
|
||||
|
||||
queue: SessionQueueStatus
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_graph",
|
||||
operation_id="enqueue_graph",
|
||||
responses={
|
||||
201: {"model": EnqueueGraphResult},
|
||||
},
|
||||
)
|
||||
async def enqueue_graph(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
graph: Graph = Body(description="The graph to enqueue"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
) -> EnqueueGraphResult:
|
||||
"""Enqueues a graph for single execution."""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_graph(queue_id=queue_id, graph=graph, prepend=prepend)
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
responses={
|
||||
201: {"model": EnqueueBatchResult},
|
||||
},
|
||||
)
|
||||
async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/list",
|
||||
operation_id="list_queue_items",
|
||||
responses={
|
||||
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
|
||||
},
|
||||
)
|
||||
async def list_queue_items(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
limit: int = Query(default=50, description="The number of items to fetch"),
|
||||
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
|
||||
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
|
||||
priority: int = Query(default=0, description="The pagination cursor priority"),
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets all queue items (without graphs)"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.list_queue_items(
|
||||
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/processor/resume",
|
||||
operation_id="resume",
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def resume(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/processor/pause",
|
||||
operation_id="pause",
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def Pause(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_batch_ids",
|
||||
operation_id="cancel_by_batch_ids",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_batch_ids(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/clear",
|
||||
operation_id="clear",
|
||||
responses={
|
||||
200: {"model": ClearResult},
|
||||
},
|
||||
)
|
||||
async def clear(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/prune",
|
||||
operation_id="prune",
|
||||
responses={
|
||||
200: {"model": PruneResult},
|
||||
},
|
||||
)
|
||||
async def prune(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/current",
|
||||
operation_id="get_current_queue_item",
|
||||
responses={
|
||||
200: {"model": Optional[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def get_current_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/next",
|
||||
operation_id="get_next_queue_item",
|
||||
responses={
|
||||
200: {"model": Optional[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def get_next_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/status",
|
||||
operation_id="get_queue_status",
|
||||
responses={
|
||||
200: {"model": SessionQueueAndProcessorStatus},
|
||||
},
|
||||
)
|
||||
async def get_queue_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/b/{batch_id}/status",
|
||||
operation_id="get_batch_status",
|
||||
responses={
|
||||
200: {"model": BatchStatus},
|
||||
},
|
||||
)
|
||||
async def get_batch_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/i/{item_id}",
|
||||
operation_id="get_queue_item",
|
||||
responses={
|
||||
200: {"model": SessionQueueItem},
|
||||
},
|
||||
)
|
||||
async def get_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
"""Gets a queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/i/{item_id}/cancel",
|
||||
operation_id="cancel_queue_item",
|
||||
responses={
|
||||
200: {"model": SessionQueueItem},
|
||||
},
|
||||
)
|
||||
async def cancel_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
@@ -23,12 +23,14 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
200: {"model": GraphExecutionState},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
||||
queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||
return session
|
||||
|
||||
|
||||
@@ -36,6 +38,7 @@ async def create_session(
|
||||
"/",
|
||||
operation_id="list_sessions",
|
||||
responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
deprecated=True,
|
||||
)
|
||||
async def list_sessions(
|
||||
page: int = Query(default=0, description="The page of results to get"),
|
||||
@@ -57,6 +60,7 @@ async def list_sessions(
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
@@ -77,6 +81,7 @@ async def get_session(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -109,6 +114,7 @@ async def add_node(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -142,6 +148,7 @@ async def update_node(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def delete_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -172,6 +179,7 @@ async def delete_node(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def add_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -203,6 +211,7 @@ async def add_edge(
|
||||
400: {"description": "Invalid node or link"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def delete_edge(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
@@ -241,8 +250,10 @@ async def delete_edge(
|
||||
400: {"description": "The session has no invocations ready to invoke"},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
deprecated=True,
|
||||
)
|
||||
async def invoke_session(
|
||||
queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
) -> Response:
|
||||
@@ -254,7 +265,7 @@ async def invoke_session(
|
||||
if session.is_complete():
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@@ -262,6 +273,7 @@ async def invoke_session(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={202: {"description": "The invocation is canceled"}},
|
||||
deprecated=True,
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
|
||||
41
invokeai/app/api/routers/utilities.py
Normal file
41
invokeai/app/api/routers/utilities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||
from fastapi import Body
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from pyparsing import ParseException
|
||||
|
||||
utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])
|
||||
|
||||
|
||||
class DynamicPromptsResponse(BaseModel):
|
||||
prompts: list[str]
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@utilities_router.post(
|
||||
"/dynamicprompts",
|
||||
operation_id="parse_dynamicprompts",
|
||||
responses={
|
||||
200: {"model": DynamicPromptsResponse},
|
||||
},
|
||||
)
|
||||
async def parse_dynamicprompts(
|
||||
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
|
||||
max_prompts: int = Body(default=1000, description="The max number of prompts to generate"),
|
||||
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
|
||||
) -> DynamicPromptsResponse:
|
||||
"""Creates a batch process"""
|
||||
try:
|
||||
error: Optional[str] = None
|
||||
if combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(prompt, max_prompts=max_prompts)
|
||||
else:
|
||||
generator = RandomPromptGenerator()
|
||||
prompts = generator.generate(prompt, num_images=max_prompts)
|
||||
except ParseException as e:
|
||||
prompts = [prompt]
|
||||
error = str(e)
|
||||
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)
|
||||
@@ -3,34 +3,35 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from fastapi_socketio import SocketManager
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
__sio: SocketManager
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
room=event[1]["data"]["queue_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
|
||||
if "queue_id" in data:
|
||||
self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
|
||||
if "queue_id" in data:
|
||||
self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# Copyright (c) 2022-2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
@@ -9,7 +8,6 @@ app_config.parse_args()
|
||||
|
||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from inspect import signature
|
||||
@@ -33,7 +31,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import app_info, board_images, boards, images, models, sessions
|
||||
from .api.routers import app_info, board_images, boards, images, models, session_queue, sessions, utilities
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
|
||||
|
||||
@@ -42,7 +40,9 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
@@ -92,6 +92,8 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
@@ -102,6 +104,8 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix="/api")
|
||||
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
@@ -147,7 +151,7 @@ def custom_openapi():
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
from invokeai.backend.model_manager.models import get_model_config_enums
|
||||
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
@@ -197,6 +201,10 @@ app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=Tru
|
||||
|
||||
|
||||
def invoke_api():
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
return
|
||||
|
||||
def find_port(port: int):
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
@@ -220,7 +228,7 @@ def invoke_api():
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.getLogger(name="jurigged").info)
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
@@ -239,7 +247,7 @@ def invoke_api():
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
for logname in ["uvicorn.access", "uvicorn"]:
|
||||
log = logging.getLogger(logname)
|
||||
log = InvokeAILogger.get_logger(logname)
|
||||
log.handlers.clear()
|
||||
for ch in logger.handlers:
|
||||
log.addHandler(ch)
|
||||
@@ -248,7 +256,4 @@ def invoke_api():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_api()
|
||||
invoke_api()
|
||||
|
||||
@@ -10,10 +10,11 @@ from pathlib import Path
|
||||
from typing import Dict, List, Literal, get_args, get_origin, get_type_hints
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.model_record_service import ModelRecordServiceBase
|
||||
from .commands import BaseCommand
|
||||
|
||||
# singleton object, class variable
|
||||
@@ -21,11 +22,11 @@ completer = None
|
||||
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, model_manager: ModelManager):
|
||||
def __init__(self, model_record_store: ModelRecordServiceBase):
|
||||
self.commands = self.get_commands()
|
||||
self.matches = None
|
||||
self.linebuffer = None
|
||||
self.manager = model_manager
|
||||
self.store = model_record_store
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
@@ -127,7 +128,7 @@ class Completer(object):
|
||||
if get_origin(typehint) == Literal:
|
||||
return get_args(typehint)
|
||||
if parameter == "model":
|
||||
return self.manager.model_names()
|
||||
return [x.name for x in self.store.model_info_by_name(model_type=ModelType.Main)]
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
@@ -142,7 +143,7 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(services.model_manager)
|
||||
completer = Completer(services.model_record_store)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
try:
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
|
||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||
import argparse
|
||||
import re
|
||||
import shlex
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, Union, get_type_hints
|
||||
@@ -29,6 +30,8 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@@ -37,6 +40,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||
from .services.download_manager import DownloadQueueService
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (
|
||||
Edge,
|
||||
@@ -51,15 +55,19 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.model_install_service import ModelInstallService
|
||||
from .services.model_loader_service import ModelLoadService
|
||||
from .services.model_record_service import ModelRecordServiceBase
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
from .services.thread import lock
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
@@ -226,7 +234,12 @@ def invoke_all(context: CliContext):
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
return
|
||||
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument("commands", nargs="*")
|
||||
@@ -237,8 +250,6 @@ def invoke_cli():
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile, "r")
|
||||
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
|
||||
events = EventServiceBase()
|
||||
output_folder = config.output_path
|
||||
|
||||
@@ -249,19 +260,25 @@ def invoke_cli():
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
download_queue = DownloadQueueService(event_bus=events)
|
||||
model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=None)
|
||||
model_loader = ModelLoadService(config, model_record_store)
|
||||
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@@ -296,19 +313,25 @@ def invoke_cli():
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
download_queue=download_queue,
|
||||
model_record_store=model_record_store,
|
||||
model_loader=model_loader,
|
||||
model_installer=model_installer,
|
||||
configuration=config,
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@@ -476,7 +499,4 @@ def invoke_cli():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
else:
|
||||
invoke_cli()
|
||||
invoke_cli()
|
||||
|
||||
@@ -68,6 +68,7 @@ class FieldDescriptions:
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
ip_adapter = "IP-Adapter to apply"
|
||||
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
@@ -88,6 +89,12 @@ class FieldDescriptions:
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
inclusive_low = "The inclusive low value"
|
||||
exclusive_high = "The exclusive high value"
|
||||
decimal_places = "The number of decimal places to round to"
|
||||
|
||||
|
||||
class Input(str, Enum):
|
||||
@@ -173,6 +180,7 @@ class UIType(str, Enum):
|
||||
WorkflowField = "WorkflowField"
|
||||
IsIntermediate = "IsIntermediate"
|
||||
MetadataField = "MetadataField"
|
||||
BoardField = "BoardField"
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -419,12 +427,27 @@ class UIConfigBase(BaseModel):
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
|
||||
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
@@ -522,6 +545,9 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
validate_all = True
|
||||
|
||||
@staticmethod
|
||||
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
uiconfig = getattr(model_class, "UIConfig", None)
|
||||
@@ -570,7 +596,29 @@ class BaseInvocation(ABC, BaseModel):
|
||||
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
|
||||
elif _input == Input.Any:
|
||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||
return self.invoke(context)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
return self.invoke(context)
|
||||
|
||||
output: BaseInvocationOutput
|
||||
if self.use_cache:
|
||||
key = context.services.invocation_cache.create_key(self)
|
||||
cached_value = context.services.invocation_cache.get(key)
|
||||
if cached_value is None:
|
||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
output = self.invoke(context)
|
||||
context.services.invocation_cache.save(key, output)
|
||||
return output
|
||||
else:
|
||||
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
return cached_value
|
||||
else:
|
||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
return self.invoke(context)
|
||||
|
||||
def get_type(self) -> str:
|
||||
return self.__fields__["type"].default
|
||||
|
||||
id: str = Field(
|
||||
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
||||
@@ -583,6 +631,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
description="The workflow to save with the image",
|
||||
ui_type=UIType.WorkflowField,
|
||||
)
|
||||
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
|
||||
|
||||
@validator("workflow", pre=True)
|
||||
def validate_workflow_is_json(cls, v):
|
||||
@@ -606,6 +655,7 @@ def invocation(
|
||||
tags: Optional[list[str]] = None,
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
@@ -614,6 +664,8 @@ def invocation(
|
||||
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||
:param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
|
||||
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
|
||||
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
|
||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
|
||||
@@ -638,6 +690,8 @@ def invocation(
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
if use_cache is not None:
|
||||
cls.__fields__["use_cache"].default = use_cache
|
||||
|
||||
# Add the invocation type to the pydantic model of the invocation
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
|
||||
@@ -56,6 +56,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
@@ -13,8 +13,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.model_manager import ModelType, UnknownModelException
|
||||
from ...backend.model_manager.lora import ModelPatcher
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -60,23 +60,23 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_loader.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_loader.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
@@ -85,7 +85,7 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
@@ -93,7 +93,7 @@ class CompelInvocation(BaseInvocation):
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
@@ -159,11 +159,11 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_loader.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_loader.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -186,12 +186,12 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
@@ -200,7 +200,7 @@ class SDXLPromptInvocationBase:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
@@ -208,7 +208,7 @@ class SDXLPromptInvocationBase:
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
|
||||
@@ -28,7 +28,7 @@ from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -38,7 +38,6 @@ from .baseinvocation import (
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -100,7 +99,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
default=1.0, description="The weight given to the ControlNet"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
@@ -560,3 +559,33 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
img[:, :] = ann_color
|
||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||
return np.array(final_img, dtype=np.uint8)
|
||||
|
||||
|
||||
@invocation(
|
||||
"color_map_image_processor",
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
image = image.convert("RGB")
|
||||
image = np.array(image, dtype=np.uint8)
|
||||
height, width = image.shape[:2]
|
||||
|
||||
width_tile_size = min(self.color_map_tile_size, width)
|
||||
height_tile_size = min(self.color_map_tile_size, height)
|
||||
|
||||
color_map = cv2.resize(
|
||||
image,
|
||||
(width // width_tile_size, height // height_tile_size),
|
||||
interpolation=cv2.INTER_CUBIC,
|
||||
)
|
||||
color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
color_map = Image.fromarray(color_map)
|
||||
return color_map
|
||||
|
||||
692
invokeai/app/invocations/facetools.py
Normal file
692
invokeai/app/invocations/facetools.py
Normal file
@@ -0,0 +1,692 @@
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mediapipe.python.solutions.face_mesh import FaceMesh # type: ignore[import]
|
||||
from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
|
||||
from PIL.Image import Image as ImageType
|
||||
from pydantic import validator
|
||||
|
||||
import invokeai.assets.fonts as font_assets
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
|
||||
|
||||
@invocation_output("face_mask_output")
|
||||
class FaceMaskOutput(ImageOutput):
|
||||
"""Base class for FaceMask output"""
|
||||
|
||||
mask: ImageField = OutputField(description="The output mask")
|
||||
|
||||
|
||||
@invocation_output("face_off_output")
|
||||
class FaceOffOutput(ImageOutput):
|
||||
"""Base class for FaceOff Output"""
|
||||
|
||||
mask: ImageField = OutputField(description="The output mask")
|
||||
x: int = OutputField(description="The x coordinate of the bounding box's left side")
|
||||
y: int = OutputField(description="The y coordinate of the bounding box's top side")
|
||||
|
||||
|
||||
class FaceResultData(TypedDict):
|
||||
image: ImageType
|
||||
mask: ImageType
|
||||
x_center: float
|
||||
y_center: float
|
||||
mesh_width: int
|
||||
mesh_height: int
|
||||
|
||||
|
||||
class FaceResultDataWithId(FaceResultData):
|
||||
face_id: int
|
||||
|
||||
|
||||
class ExtractFaceData(TypedDict):
|
||||
bounded_image: ImageType
|
||||
bounded_mask: ImageType
|
||||
x_min: int
|
||||
y_min: int
|
||||
x_max: int
|
||||
y_max: int
|
||||
|
||||
|
||||
class FaceMaskResult(TypedDict):
|
||||
image: ImageType
|
||||
mask: ImageType
|
||||
|
||||
|
||||
def create_white_image(w: int, h: int) -> ImageType:
|
||||
return Image.new("L", (w, h), color=255)
|
||||
|
||||
|
||||
def create_black_image(w: int, h: int) -> ImageType:
|
||||
return Image.new("L", (w, h), color=0)
|
||||
|
||||
|
||||
FONT_SIZE = 32
|
||||
FONT_STROKE_WIDTH = 4
|
||||
|
||||
|
||||
def prepare_faces_list(
|
||||
face_result_list: list[FaceResultData],
|
||||
) -> list[FaceResultDataWithId]:
|
||||
"""Deduplicates a list of faces, adding IDs to them."""
|
||||
deduped_faces: list[FaceResultData] = []
|
||||
|
||||
if len(face_result_list) == 0:
|
||||
return list()
|
||||
|
||||
for candidate in face_result_list:
|
||||
should_add = True
|
||||
candidate_x_center = candidate["x_center"]
|
||||
candidate_y_center = candidate["y_center"]
|
||||
for face in deduped_faces:
|
||||
face_center_x = face["x_center"]
|
||||
face_center_y = face["y_center"]
|
||||
face_radius_w = face["mesh_width"] / 2
|
||||
face_radius_h = face["mesh_height"] / 2
|
||||
# Determine if the center of the candidate_face is inside the ellipse of the added face
|
||||
# p < 1 -> Inside
|
||||
# p = 1 -> Exactly on the ellipse
|
||||
# p > 1 -> Outside
|
||||
p = (math.pow((candidate_x_center - face_center_x), 2) / math.pow(face_radius_w, 2)) + (
|
||||
math.pow((candidate_y_center - face_center_y), 2) / math.pow(face_radius_h, 2)
|
||||
)
|
||||
|
||||
if p < 1: # Inside of the already-added face's radius
|
||||
should_add = False
|
||||
break
|
||||
|
||||
if should_add is True:
|
||||
deduped_faces.append(candidate)
|
||||
|
||||
sorted_faces = sorted(deduped_faces, key=lambda x: x["y_center"])
|
||||
sorted_faces = sorted(sorted_faces, key=lambda x: x["x_center"])
|
||||
|
||||
# add face_id for reference
|
||||
sorted_faces_with_ids: list[FaceResultDataWithId] = []
|
||||
face_id_counter = 0
|
||||
for face in sorted_faces:
|
||||
sorted_faces_with_ids.append(
|
||||
FaceResultDataWithId(
|
||||
**face,
|
||||
face_id=face_id_counter,
|
||||
)
|
||||
)
|
||||
face_id_counter += 1
|
||||
|
||||
return sorted_faces_with_ids
|
||||
|
||||
|
||||
def generate_face_box_mask(
|
||||
context: InvocationContext,
|
||||
minimum_confidence: float,
|
||||
x_offset: float,
|
||||
y_offset: float,
|
||||
pil_image: ImageType,
|
||||
chunk_x_offset: int = 0,
|
||||
chunk_y_offset: int = 0,
|
||||
draw_mesh: bool = True,
|
||||
check_bounds: bool = True,
|
||||
) -> list[FaceResultData]:
|
||||
result = []
|
||||
mask_pil = None
|
||||
|
||||
# Convert the PIL image to a NumPy array.
|
||||
np_image = np.array(pil_image, dtype=np.uint8)
|
||||
|
||||
# Check if the input image has four channels (RGBA).
|
||||
if np_image.shape[2] == 4:
|
||||
# Convert RGBA to RGB by removing the alpha channel.
|
||||
np_image = np_image[:, :, :3]
|
||||
|
||||
# Create a FaceMesh object for face landmark detection and mesh generation.
|
||||
face_mesh = FaceMesh(
|
||||
max_num_faces=999,
|
||||
min_detection_confidence=minimum_confidence,
|
||||
min_tracking_confidence=minimum_confidence,
|
||||
)
|
||||
|
||||
# Detect the face landmarks and mesh in the input image.
|
||||
results = face_mesh.process(np_image)
|
||||
|
||||
# Check if any face is detected.
|
||||
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
|
||||
# Search for the face_id in the detected faces.
|
||||
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
|
||||
# Get the bounding box of the face mesh.
|
||||
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
|
||||
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
|
||||
x_min, x_max = min(x_coordinates), max(x_coordinates)
|
||||
y_min, y_max = min(y_coordinates), max(y_coordinates)
|
||||
|
||||
# Calculate the width and height of the face mesh.
|
||||
mesh_width = int((x_max - x_min) * np_image.shape[1])
|
||||
mesh_height = int((y_max - y_min) * np_image.shape[0])
|
||||
|
||||
# Get the center of the face.
|
||||
x_center = np.mean([landmark.x * np_image.shape[1] for landmark in face_landmarks.landmark])
|
||||
y_center = np.mean([landmark.y * np_image.shape[0] for landmark in face_landmarks.landmark])
|
||||
|
||||
face_landmark_points = np.array(
|
||||
[
|
||||
[landmark.x * np_image.shape[1], landmark.y * np_image.shape[0]]
|
||||
for landmark in face_landmarks.landmark
|
||||
]
|
||||
)
|
||||
|
||||
# Apply the scaling offsets to the face landmark points with a multiplier.
|
||||
scale_multiplier = 0.2
|
||||
x_center = np.mean(face_landmark_points[:, 0])
|
||||
y_center = np.mean(face_landmark_points[:, 1])
|
||||
|
||||
if draw_mesh:
|
||||
x_scaled = face_landmark_points[:, 0] + scale_multiplier * x_offset * (
|
||||
face_landmark_points[:, 0] - x_center
|
||||
)
|
||||
y_scaled = face_landmark_points[:, 1] + scale_multiplier * y_offset * (
|
||||
face_landmark_points[:, 1] - y_center
|
||||
)
|
||||
|
||||
convex_hull = cv2.convexHull(np.column_stack((x_scaled, y_scaled)).astype(np.int32))
|
||||
|
||||
# Generate a binary face mask using the face mesh.
|
||||
mask_image = np.ones(np_image.shape[:2], dtype=np.uint8) * 255
|
||||
cv2.fillConvexPoly(mask_image, convex_hull, 0)
|
||||
|
||||
# Convert the binary mask image to a PIL Image.
|
||||
init_mask_pil = Image.fromarray(mask_image, mode="L")
|
||||
w, h = init_mask_pil.size
|
||||
mask_pil = create_white_image(w + chunk_x_offset, h + chunk_y_offset)
|
||||
mask_pil.paste(init_mask_pil, (chunk_x_offset, chunk_y_offset))
|
||||
|
||||
left_side = x_center - mesh_width
|
||||
right_side = x_center + mesh_width
|
||||
top_side = y_center - mesh_height
|
||||
bottom_side = y_center + mesh_height
|
||||
im_width, im_height = pil_image.size
|
||||
over_w = im_width * 0.1
|
||||
over_h = im_height * 0.1
|
||||
if not check_bounds or (
|
||||
(left_side >= -over_w)
|
||||
and (right_side < im_width + over_w)
|
||||
and (top_side >= -over_h)
|
||||
and (bottom_side < im_height + over_h)
|
||||
):
|
||||
x_center = float(x_center)
|
||||
y_center = float(y_center)
|
||||
face = FaceResultData(
|
||||
image=pil_image,
|
||||
mask=mask_pil or create_white_image(*pil_image.size),
|
||||
x_center=x_center + chunk_x_offset,
|
||||
y_center=y_center + chunk_y_offset,
|
||||
mesh_width=mesh_width,
|
||||
mesh_height=mesh_height,
|
||||
)
|
||||
|
||||
result.append(face)
|
||||
else:
|
||||
context.services.logger.info("FaceTools --> Face out of bounds, ignoring.")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def extract_face(
|
||||
context: InvocationContext,
|
||||
image: ImageType,
|
||||
face: FaceResultData,
|
||||
padding: int,
|
||||
) -> ExtractFaceData:
|
||||
mask = face["mask"]
|
||||
center_x = face["x_center"]
|
||||
center_y = face["y_center"]
|
||||
mesh_width = face["mesh_width"]
|
||||
mesh_height = face["mesh_height"]
|
||||
|
||||
# Determine the minimum size of the square crop
|
||||
min_size = min(mask.width, mask.height)
|
||||
|
||||
# Calculate the crop boundaries for the output image and mask.
|
||||
mesh_width += 128 + padding # add pixels to account for mask variance
|
||||
mesh_height += 128 + padding # add pixels to account for mask variance
|
||||
crop_size = min(
|
||||
max(mesh_width, mesh_height, 128), min_size
|
||||
) # Choose the smaller of the two (given value or face mask size)
|
||||
if crop_size > 128:
|
||||
crop_size = (crop_size + 7) // 8 * 8 # Ensure crop side is multiple of 8
|
||||
|
||||
# Calculate the actual crop boundaries within the bounds of the original image.
|
||||
x_min = int(center_x - crop_size / 2)
|
||||
y_min = int(center_y - crop_size / 2)
|
||||
x_max = int(center_x + crop_size / 2)
|
||||
y_max = int(center_y + crop_size / 2)
|
||||
|
||||
# Adjust the crop boundaries to stay within the original image's dimensions
|
||||
if x_min < 0:
|
||||
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||
x_max -= x_min
|
||||
x_min = 0
|
||||
elif x_max > mask.width:
|
||||
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||
x_min -= x_max - mask.width
|
||||
x_max = mask.width
|
||||
|
||||
if y_min < 0:
|
||||
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||
y_max -= y_min
|
||||
y_min = 0
|
||||
elif y_max > mask.height:
|
||||
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||
y_min -= y_max - mask.height
|
||||
y_max = mask.height
|
||||
|
||||
# Ensure the crop is square and adjust the boundaries if needed
|
||||
if x_max - x_min != crop_size:
|
||||
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||
diff = crop_size - (x_max - x_min)
|
||||
x_min -= diff // 2
|
||||
x_max += diff - diff // 2
|
||||
|
||||
if y_max - y_min != crop_size:
|
||||
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||
diff = crop_size - (y_max - y_min)
|
||||
y_min -= diff // 2
|
||||
y_max += diff - diff // 2
|
||||
|
||||
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||
|
||||
# Crop the output image to the specified size with the center of the face mesh as the center.
|
||||
mask = mask.crop((x_min, y_min, x_max, y_max))
|
||||
bounded_image = image.crop((x_min, y_min, x_max, y_max))
|
||||
|
||||
# blur mask edge by small radius
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(radius=2))
|
||||
|
||||
return ExtractFaceData(
|
||||
bounded_image=bounded_image,
|
||||
bounded_mask=mask,
|
||||
x_min=x_min,
|
||||
y_min=y_min,
|
||||
x_max=x_max,
|
||||
y_max=y_max,
|
||||
)
|
||||
|
||||
|
||||
def get_faces_list(
|
||||
context: InvocationContext,
|
||||
image: ImageType,
|
||||
should_chunk: bool,
|
||||
minimum_confidence: float,
|
||||
x_offset: float,
|
||||
y_offset: float,
|
||||
draw_mesh: bool = True,
|
||||
) -> list[FaceResultDataWithId]:
|
||||
result = []
|
||||
|
||||
# Generate the face box mask and get the center of the face.
|
||||
if not should_chunk:
|
||||
context.services.logger.info("FaceTools --> Attempting full image face detection.")
|
||||
result = generate_face_box_mask(
|
||||
context=context,
|
||||
minimum_confidence=minimum_confidence,
|
||||
x_offset=x_offset,
|
||||
y_offset=y_offset,
|
||||
pil_image=image,
|
||||
chunk_x_offset=0,
|
||||
chunk_y_offset=0,
|
||||
draw_mesh=draw_mesh,
|
||||
check_bounds=False,
|
||||
)
|
||||
if should_chunk or len(result) == 0:
|
||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||
width, height = image.size
|
||||
image_chunks = []
|
||||
x_offsets = []
|
||||
y_offsets = []
|
||||
result = []
|
||||
|
||||
# If width == height, there's nothing more we can do... otherwise...
|
||||
if width > height:
|
||||
# Landscape - slice the image horizontally
|
||||
fx = 0.0
|
||||
steps = int(width * 2 / height)
|
||||
while fx <= (width - height):
|
||||
x = int(fx)
|
||||
image_chunks.append(image.crop((x, 0, x + height - 1, height - 1)))
|
||||
x_offsets.append(x)
|
||||
y_offsets.append(0)
|
||||
fx += (width - height) / steps
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||
elif height > width:
|
||||
# Portrait - slice the image vertically
|
||||
fy = 0.0
|
||||
steps = int(height * 2 / width)
|
||||
while fy <= (height - width):
|
||||
y = int(fy)
|
||||
image_chunks.append(image.crop((0, y, width - 1, y + width - 1)))
|
||||
x_offsets.append(0)
|
||||
y_offsets.append(y)
|
||||
fy += (height - width) / steps
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||
|
||||
for idx in range(len(image_chunks)):
|
||||
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||
result = result + generate_face_box_mask(
|
||||
context=context,
|
||||
minimum_confidence=minimum_confidence,
|
||||
x_offset=x_offset,
|
||||
y_offset=y_offset,
|
||||
pil_image=image_chunks[idx],
|
||||
chunk_x_offset=x_offsets[idx],
|
||||
chunk_y_offset=y_offsets[idx],
|
||||
draw_mesh=draw_mesh,
|
||||
)
|
||||
|
||||
if len(result) == 0:
|
||||
# Give up
|
||||
context.services.logger.warning(
|
||||
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
||||
)
|
||||
|
||||
all_faces = prepare_faces_list(result)
|
||||
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.1")
|
||||
class FaceOffInvocation(BaseInvocation):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image for face detection")
|
||||
face_id: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="The face ID to process, numbered from 0. Multiple faces not supported. Find a face's ID with FaceIdentifier node.",
|
||||
)
|
||||
minimum_confidence: float = InputField(
|
||||
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
|
||||
)
|
||||
x_offset: float = InputField(default=0.0, description="X-axis offset of the mask")
|
||||
y_offset: float = InputField(default=0.0, description="Y-axis offset of the mask")
|
||||
padding: int = InputField(default=0, description="All-axis padding around the mask in pixels")
|
||||
chunk: bool = InputField(
|
||||
default=False,
|
||||
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
||||
)
|
||||
|
||||
def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
|
||||
all_faces = get_faces_list(
|
||||
context=context,
|
||||
image=image,
|
||||
should_chunk=self.chunk,
|
||||
minimum_confidence=self.minimum_confidence,
|
||||
x_offset=self.x_offset,
|
||||
y_offset=self.y_offset,
|
||||
draw_mesh=True,
|
||||
)
|
||||
|
||||
if len(all_faces) == 0:
|
||||
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||
return None
|
||||
|
||||
if self.face_id > len(all_faces) - 1:
|
||||
context.services.logger.warning(
|
||||
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
||||
)
|
||||
return None
|
||||
|
||||
face_data = extract_face(context=context, image=image, face=all_faces[self.face_id], padding=self.padding)
|
||||
# Convert the input image to RGBA mode to ensure it has an alpha channel.
|
||||
face_data["bounded_image"] = face_data["bounded_image"].convert("RGBA")
|
||||
|
||||
return face_data
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
result = self.faceoff(context=context, image=image)
|
||||
|
||||
if result is None:
|
||||
result_image = image
|
||||
result_mask = create_white_image(*image.size)
|
||||
x = 0
|
||||
y = 0
|
||||
else:
|
||||
result_image = result["bounded_image"]
|
||||
result_mask = result["bounded_mask"]
|
||||
x = result["x_min"]
|
||||
y = result["y_min"]
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=result_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
image=result_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
output = FaceOffOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
mask=ImageField(image_name=mask_dto.image_name),
|
||||
x=x,
|
||||
y=y,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.1")
|
||||
class FaceMaskInvocation(BaseInvocation):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
face_ids: str = InputField(
|
||||
default="",
|
||||
description="Comma-separated list of face ids to mask eg '0,2,7'. Numbered from 0. Leave empty to mask all. Find face IDs with FaceIdentifier node.",
|
||||
)
|
||||
minimum_confidence: float = InputField(
|
||||
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
|
||||
)
|
||||
x_offset: float = InputField(default=0.0, description="Offset for the X-axis of the face mask")
|
||||
y_offset: float = InputField(default=0.0, description="Offset for the Y-axis of the face mask")
|
||||
chunk: bool = InputField(
|
||||
default=False,
|
||||
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
||||
)
|
||||
invert_mask: bool = InputField(default=False, description="Toggle to invert the mask")
|
||||
|
||||
@validator("face_ids")
|
||||
def validate_comma_separated_ints(cls, v) -> str:
|
||||
comma_separated_ints_regex = re.compile(r"^\d*(,\d+)*$")
|
||||
if comma_separated_ints_regex.match(v) is None:
|
||||
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
|
||||
return v
|
||||
|
||||
def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
|
||||
all_faces = get_faces_list(
|
||||
context=context,
|
||||
image=image,
|
||||
should_chunk=self.chunk,
|
||||
minimum_confidence=self.minimum_confidence,
|
||||
x_offset=self.x_offset,
|
||||
y_offset=self.y_offset,
|
||||
draw_mesh=True,
|
||||
)
|
||||
|
||||
mask_pil = create_white_image(*image.size)
|
||||
|
||||
id_range = list(range(0, len(all_faces)))
|
||||
ids_to_extract = id_range
|
||||
if self.face_ids != "":
|
||||
parsed_face_ids = [int(id) for id in self.face_ids.split(",")]
|
||||
# get requested face_ids that are in range
|
||||
intersected_face_ids = set(parsed_face_ids) & set(id_range)
|
||||
|
||||
if len(intersected_face_ids) == 0:
|
||||
id_range_str = ",".join([str(id) for id in id_range])
|
||||
context.services.logger.warning(
|
||||
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
|
||||
)
|
||||
return FaceMaskResult(
|
||||
image=image, # original image
|
||||
mask=mask_pil, # white mask
|
||||
)
|
||||
|
||||
ids_to_extract = list(intersected_face_ids)
|
||||
|
||||
for face_id in ids_to_extract:
|
||||
face_data = extract_face(context=context, image=image, face=all_faces[face_id], padding=0)
|
||||
face_mask_pil = face_data["bounded_mask"]
|
||||
x_min = face_data["x_min"]
|
||||
y_min = face_data["y_min"]
|
||||
x_max = face_data["x_max"]
|
||||
y_max = face_data["y_max"]
|
||||
|
||||
mask_pil.paste(
|
||||
create_black_image(x_max - x_min, y_max - y_min),
|
||||
box=(x_min, y_min),
|
||||
mask=ImageOps.invert(face_mask_pil),
|
||||
)
|
||||
|
||||
if self.invert_mask:
|
||||
mask_pil = ImageOps.invert(mask_pil)
|
||||
|
||||
# Create an RGBA image with transparency
|
||||
image = image.convert("RGBA")
|
||||
|
||||
return FaceMaskResult(
|
||||
image=image,
|
||||
mask=mask_pil,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
result = self.facemask(context=context, image=image)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=result["image"],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
image=result["mask"],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
output = FaceMaskOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
mask=ImageField(image_name=mask_dto.image_name),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.1"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
minimum_confidence: float = InputField(
|
||||
default=0.5, description="Minimum confidence for face detection (lower if detection is failing)"
|
||||
)
|
||||
chunk: bool = InputField(
|
||||
default=False,
|
||||
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
||||
)
|
||||
|
||||
def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
|
||||
image = image.copy()
|
||||
|
||||
all_faces = get_faces_list(
|
||||
context=context,
|
||||
image=image,
|
||||
should_chunk=self.chunk,
|
||||
minimum_confidence=self.minimum_confidence,
|
||||
x_offset=0,
|
||||
y_offset=0,
|
||||
draw_mesh=False,
|
||||
)
|
||||
|
||||
# Note - font may be found either in the repo if running an editable install, or in the venv if running a package install
|
||||
font_path = [x for x in [Path(y, "inter/Inter-Regular.ttf") for y in font_assets.__path__] if x.exists()]
|
||||
font = ImageFont.truetype(font_path[0].as_posix(), FONT_SIZE)
|
||||
|
||||
# Paste face IDs on the output image
|
||||
draw = ImageDraw.Draw(image)
|
||||
for face in all_faces:
|
||||
x_coord = face["x_center"]
|
||||
y_coord = face["y_center"]
|
||||
text = str(face["face_id"])
|
||||
# get bbox of the text so we can center the id on the face
|
||||
_, _, bbox_w, bbox_h = draw.textbbox(xy=(0, 0), text=text, font=font, stroke_width=FONT_STROKE_WIDTH)
|
||||
x = x_coord - bbox_w / 2
|
||||
y = y_coord - bbox_h / 2
|
||||
draw.text(
|
||||
xy=(x, y),
|
||||
text=str(text),
|
||||
fill=(255, 255, 255, 255),
|
||||
font=font,
|
||||
stroke_width=FONT_STROKE_WIDTH,
|
||||
stroke_fill=(0, 0, 0, 255),
|
||||
)
|
||||
|
||||
# Create an RGBA image with transparency
|
||||
image = image.convert("RGBA")
|
||||
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
result_image = self.faceidentifier(context=context, image=image)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=result_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@@ -8,12 +8,12 @@ import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||
@@ -965,3 +965,44 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"save_image",
|
||||
title="Save Image",
|
||||
tags=["primitives", "image"],
|
||||
category="primitives",
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
class SaveImageInvocation(BaseInvocation):
|
||||
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||
metadata: CoreMetadata = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.core_metadata,
|
||||
ui_hidden=True,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
board_id=self.board.board_id if self.board else None,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.dict() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -269,7 +269,7 @@ class LaMaInfillInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
class CV2InfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
|
||||
|
||||
class IPAdapterModelField(BaseModel):
|
||||
@@ -58,9 +58,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
input=Input.Direct,
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
@@ -31,20 +33,21 @@ from invokeai.app.invocations.primitives import (
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SilenceWarnings
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.model_manager.lora import ModelPatcher
|
||||
from ...backend.model_manager.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
IPAdapterData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
T2IAdapterData,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
@@ -129,7 +132,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -162,7 +165,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
orig_scheduler_info = context.services.model_loader.get_model(
|
||||
**scheduler_info.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -194,7 +197,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
version="1.3.0",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@@ -208,7 +211,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
|
||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
@@ -218,16 +221,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
)
|
||||
ip_adapter: Optional[IPAdapterField] = InputField(
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
||||
)
|
||||
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
|
||||
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
|
||||
)
|
||||
|
||||
@validator("cfg_scale")
|
||||
@@ -356,7 +361,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
@@ -403,52 +408,150 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def prep_ip_adapter_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[IPAdapterField],
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
conditioning_data: ConditioningData,
|
||||
unet: UNet2DConditionModel,
|
||||
exit_stack: ExitStack,
|
||||
) -> Optional[IPAdapterData]:
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
to the `conditioning_data` (in-place).
|
||||
"""
|
||||
if ip_adapter is None:
|
||||
return None
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
model_name=ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=ip_adapter.image_encoder_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
|
||||
if not isinstance(ip_adapter, list):
|
||||
ip_adapter = [ip_adapter]
|
||||
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=ip_adapter.ip_adapter_model.base_model,
|
||||
if len(ip_adapter) == 0:
|
||||
return None
|
||||
|
||||
ip_adapter_data_list = []
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_loader.get_model(
|
||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
|
||||
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
input_image, image_encoder_model
|
||||
)
|
||||
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
|
||||
image_prompt_embeds, uncond_image_prompt_embeds
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
input_image, image_encoder_model
|
||||
)
|
||||
conditioning_data.ip_adapter_conditioning.append(
|
||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||
)
|
||||
|
||||
ip_adapter_data_list.append(
|
||||
IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=single_ip_adapter.weight,
|
||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
)
|
||||
)
|
||||
|
||||
return IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=ip_adapter.weight,
|
||||
begin_step_percent=ip_adapter.begin_step_percent,
|
||||
end_step_percent=ip_adapter.end_step_percent,
|
||||
)
|
||||
return ip_adapter_data_list
|
||||
|
||||
def run_t2i_adapters(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
latents_shape: list[int],
|
||||
do_classifier_free_guidance: bool,
|
||||
) -> Optional[list[T2IAdapterData]]:
|
||||
if t2i_adapter is None:
|
||||
return None
|
||||
|
||||
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
|
||||
if isinstance(t2i_adapter, T2IAdapterField):
|
||||
t2i_adapter = [t2i_adapter]
|
||||
|
||||
if len(t2i_adapter) == 0:
|
||||
return None
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_loader.get_model(
|
||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||
model_type=ModelType.T2IAdapter,
|
||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||
max_unet_downscale = 8
|
||||
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
|
||||
max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
|
||||
)
|
||||
|
||||
t2i_adapter_model: T2IAdapter
|
||||
with t2i_adapter_model_info as t2i_adapter_model:
|
||||
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
||||
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
|
||||
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
|
||||
# this PR: https://github.com/huggingface/diffusers/pull/5134.
|
||||
total_downscale_factor = total_downscale_factor // 2
|
||||
|
||||
# Resize the T2I-Adapter input image.
|
||||
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
|
||||
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
|
||||
|
||||
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||
# T2I-Adapter model.
|
||||
#
|
||||
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||
# of the same requirements (e.g. preserving binary masks during resize).
|
||||
t2i_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=t2i_input_width,
|
||||
height=t2i_input_height,
|
||||
num_channels=t2i_adapter_model.config.in_channels,
|
||||
device=t2i_adapter_model.device,
|
||||
dtype=t2i_adapter_model.dtype,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
|
||||
adapter_state = t2i_adapter_model(t2i_image)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
for idx, value in enumerate(adapter_state):
|
||||
adapter_state[idx] = torch.cat([value] * 2, dim=0)
|
||||
|
||||
t2i_adapter_data.append(
|
||||
T2IAdapterData(
|
||||
adapter_state=adapter_state,
|
||||
weight=t2i_adapter_field.weight,
|
||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||
)
|
||||
)
|
||||
|
||||
return t2i_adapter_data
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
@@ -521,6 +624,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
t2i_adapter_data = self.run_t2i_adapters(
|
||||
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
@@ -530,7 +639,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_loader.get_model(
|
||||
**lora.dict(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
@@ -538,7 +647,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
unet_info = context.services.model_loader.get_model(
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -579,7 +688,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
conditioning_data=conditioning_data,
|
||||
unet=unet,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
@@ -601,8 +709,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masked_latents=masked_latents,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data, # list[ControlNetData],
|
||||
ip_adapter_data=ip_adapter_data, # IPAdapterData,
|
||||
control_data=controlnet_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
@@ -643,7 +752,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -857,8 +966,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
@@ -869,7 +977,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
context=context,
|
||||
)
|
||||
@@ -885,6 +993,18 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@_encode_to_tensor.register
|
||||
@staticmethod
|
||||
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
|
||||
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
|
||||
@@ -54,17 +54,38 @@ class DivideInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
|
||||
@invocation(
|
||||
"rand_int",
|
||||
title="Random Integer",
|
||||
tags=["math", "random"],
|
||||
category="math",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
low: int = InputField(default=0, description="The inclusive low value")
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
||||
|
||||
@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0")
|
||||
class RandomFloatInvocation(BaseInvocation):
|
||||
"""Outputs a single random float"""
|
||||
|
||||
low: float = InputField(default=0.0, description=FieldDescriptions.inclusive_low)
|
||||
high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
|
||||
decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
random_float = np.random.uniform(self.low, self.high)
|
||||
rounded_float = round(random_float, self.decimals)
|
||||
return FloatOutput(value=rounded_float)
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_to_int",
|
||||
title="Float To Integer",
|
||||
|
||||
@@ -12,7 +12,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
from ...version import __version__
|
||||
@@ -25,6 +28,18 @@ class LoRAMetadataField(BaseModelExcludeNull):
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
|
||||
|
||||
class IPAdapterMetadataField(BaseModelExcludeNull):
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
weight: float = Field(description="The weight of the IP-Adapter model")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
|
||||
|
||||
class CoreMetadata(BaseModelExcludeNull):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
@@ -42,11 +57,14 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
|
||||
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
vae: Optional[VAEModelField] = Field(
|
||||
default=None,
|
||||
@@ -116,11 +134,14 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
cfg_scale: float = InputField(description="The classifier-free guidance scale parameter")
|
||||
steps: int = InputField(description="The number of steps used for inference")
|
||||
scheduler: str = InputField(description="The scheduler used for inference")
|
||||
clip_skip: int = InputField(
|
||||
clip_skip: Optional[int] = Field(
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = InputField(description="The main model used for inference")
|
||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
|
||||
t2iAdapters: list[T2IAdapterField] = Field(description="The IP Adapters used for inference")
|
||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||
strength: Optional[float] = InputField(
|
||||
default=None,
|
||||
|
||||
@@ -3,7 +3,8 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -19,9 +20,7 @@ from .baseinvocation import (
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
key: str = Field(description="Unique ID for model")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
@@ -61,16 +60,13 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
key: str = Field(description="Unique ID of the model")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
key: str = Field(description="Unique ID for model")
|
||||
|
||||
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
||||
@@ -81,20 +77,15 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
"""Load a main model, outputting its submodels."""
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
@@ -103,7 +94,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
@@ -112,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
@@ -125,30 +116,22 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
@@ -156,9 +139,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
@@ -167,7 +148,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("lora_loader_output")
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
"""Model loader output."""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
@@ -187,24 +168,20 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
"""Load a LoRA model."""
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unknown lora: {key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
@@ -212,9 +189,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -224,9 +199,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -237,7 +210,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
"""SDXL LoRA Loader Output."""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
@@ -261,27 +234,22 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
"""Load an SDXL LoRA."""
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
key = self.lora.key
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unknown lora name: {key}!")
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to unet')
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to clip')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
@@ -289,9 +257,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -301,9 +267,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -313,9 +277,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
key=key,
|
||||
submodel=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
@@ -325,10 +287,9 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
"""Vae model field."""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
key: str = Field(description="Unique ID for VAE model")
|
||||
|
||||
|
||||
@invocation_output("vae_loader_output")
|
||||
@@ -340,29 +301,22 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput."""
|
||||
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
"""Load a VAE model."""
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
if not context.services.model_record_store.model_exists(key):
|
||||
raise Exception(f"Unkown vae name: {key}!")
|
||||
return VaeLoaderOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
key=key,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -370,7 +324,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
"""Modified Seamless Model output."""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
@@ -390,6 +344,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||
"""Apply seamless transformation."""
|
||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||
unet = copy.deepcopy(self.unet)
|
||||
vae = copy.deepcopy(self.vae)
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.model_manager.lora import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
@@ -62,15 +62,15 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_loader.get_model(
|
||||
**self.clip.tokenizer.dict(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_loader.get_model(
|
||||
**self.clip.text_encoder.dict(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.clip.loras
|
||||
]
|
||||
|
||||
@@ -81,7 +81,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_loader.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
@@ -166,7 +166,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
ui_type=UIType.Float,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
@@ -179,7 +178,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
ui_type=UIType.Control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
@@ -256,12 +254,12 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
unet_info = context.services.model_loader.get_model(**self.unet.unet.dict())
|
||||
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
# loras = [(stack.enter_context(context.services.model_loader.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||
for lora in self.unet.loras
|
||||
]
|
||||
|
||||
@@ -347,7 +345,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_loader.get_model(
|
||||
**self.vae.vae.dict(),
|
||||
)
|
||||
|
||||
@@ -420,7 +418,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.ONNX
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@@ -428,7 +426,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
@@ -437,7 +435,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
@@ -446,7 +444,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
|
||||
@@ -226,6 +226,12 @@ class ImageField(BaseModel):
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class BoardField(BaseModel):
|
||||
"""A board primitive field"""
|
||||
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
@invocation_output("image_output")
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
@@ -10,7 +10,14 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||
|
||||
|
||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
|
||||
@invocation(
|
||||
"dynamic_prompt",
|
||||
title="Dynamic Prompt",
|
||||
tags=["prompt", "collection"],
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from ...backend.model_manager import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -48,7 +48,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@@ -137,7 +137,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.services.model_record_store.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
|
||||
83
invokeai/app/invocations/t2i_adapter.py
Normal file
83
invokeai/app/invocations/t2i_adapter.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
@invocation_output("t2i_adapter_output")
|
||||
class T2IAdapterOutput(BaseInvocationOutput):
|
||||
t2i_adapter: T2IAdapterField = OutputField(description=FieldDescriptions.t2i_adapter, title="T2I Adapter")
|
||||
|
||||
|
||||
@invocation(
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
|
||||
)
|
||||
class T2IAdapterInvocation(BaseInvocation):
|
||||
"""Collects T2I-Adapter info to pass to other nodes."""
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
)
|
||||
weight: Union[float, list[float]] = InputField(
|
||||
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(
|
||||
default="just_resize",
|
||||
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
|
||||
return T2IAdapterOutput(
|
||||
t2i_adapter=T2IAdapterField(
|
||||
image=self.image,
|
||||
t2i_adapter_model=self.t2i_adapter_model,
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
resize_mode=self.resize_mode,
|
||||
)
|
||||
)
|
||||
@@ -4,12 +4,14 @@ from typing import Literal
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
@@ -22,13 +24,19 @@ ESRGAN_MODELS = Literal[
|
||||
"RealESRGAN_x2plus.pth",
|
||||
]
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||
tile_size: int = InputField(
|
||||
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
@@ -86,9 +94,11 @@ class ESRGANInvocation(BaseInvocation):
|
||||
model_path=str(models_path / esrgan_model_path),
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||
|
||||
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||
@@ -99,6 +109,10 @@ class ESRGANInvocation(BaseInvocation):
|
||||
# back to PIL
|
||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
|
||||
@@ -53,24 +53,20 @@ class BoardImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@@ -8,6 +7,7 @@ from pydantic import BaseModel, Extra, Field
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
@@ -87,24 +87,20 @@ class BoardRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
@@ -174,7 +170,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = str(uuid.uuid4())
|
||||
board_id = uuid_string()
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
|
||||
@@ -16,7 +16,7 @@ import pydoc
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from pydantic import BaseSettings
|
||||
@@ -25,6 +25,7 @@ from pydantic import BaseSettings
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
"""
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
@@ -39,10 +40,10 @@ class InvokeAISettings(BaseSettings):
|
||||
read from an omegaconf .yaml file.
|
||||
"""
|
||||
|
||||
initconf: ClassVar[DictConfig] = None
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list = sys.argv[1:]):
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
if len(unknown_opts) > 0:
|
||||
@@ -83,7 +84,8 @@ class InvokeAISettings(BaseSettings):
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||
env_prefix = getattr(cls.Config, "env_prefix", None)
|
||||
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
@@ -116,8 +118,8 @@ class InvokeAISettings(BaseSettings):
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(self)
|
||||
def cmd_name(cls, command_field: str = "type") -> str:
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
@@ -133,30 +135,16 @@ class InvokeAISettings(BaseSettings):
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self) -> List[str]:
|
||||
def _excluded(cls) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(self) -> List[str]:
|
||||
def _excluded_from_yaml(cls) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
]
|
||||
|
||||
class Config:
|
||||
@@ -229,9 +217,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
"""Workaround for argparse type checking."""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
|
||||
@@ -171,6 +171,7 @@ two configs are kept in separate sections of the config file:
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
|
||||
@@ -182,7 +183,9 @@ from .base import InvokeAISettings
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
DEFAULT_MAX_DISK_CACHE = 20 # gigs, enough for three sdxl models, or 6 sd-1 models
|
||||
DEFAULT_RAM_CACHE = 7.5
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
@@ -194,8 +197,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
"""
|
||||
|
||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||
singleton_init: ClassVar[Dict] = None
|
||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||
singleton_init: ClassVar[Optional[Dict]] = None
|
||||
|
||||
# fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
@@ -217,11 +220,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
# PATHS
|
||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
autoimport_dir : Optional[Path] = Field(default=None, description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||
model_config_db : Union[Path, Literal['auto'], None] = Field(default=None, description='Path to a sqlite .db file or .yaml file for storing model config records; "auto" will reuse the main sqlite db', category='Paths')
|
||||
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||
@@ -234,29 +234,36 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", category="Logging")
|
||||
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", category="Development")
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching", category="Model Cache", )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage", category="Model Cache", )
|
||||
disk : float = Field(default=DEFAULT_MAX_DISK_CACHE, ge=0, description="Maximum size (in GB) for the disk-based diffusers model conversion cache", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
# DEVICE
|
||||
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", category="Device", )
|
||||
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||
|
||||
# GENERATION
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
|
||||
|
||||
# QUEUE
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
|
||||
|
||||
# NODES
|
||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", category="Nodes")
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Nodes", )
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
@@ -265,14 +272,19 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
env_prefix = "INVOKEAI"
|
||||
|
||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
||||
def parse_args(self, argv: Optional[list[str]] = None, conf: Optional[DictConfig] = None, clobber=False):
|
||||
"""
|
||||
Update settings with contents of init file, environment, and
|
||||
command-line settings.
|
||||
@@ -283,12 +295,16 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
loaded_conf = None
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
if isinstance(loaded_conf, DictConfig):
|
||||
InvokeAISettings.initconf = loaded_conf
|
||||
else:
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
super().parse_args(argv)
|
||||
@@ -300,9 +316,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
"""
|
||||
"""This returns a singleton InvokeAIAppConfig configuration object."""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
or type(cls.singleton_config) is not cls
|
||||
@@ -312,6 +326,29 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(cls) -> List[str]:
|
||||
el = super()._excluded_from_yaml()
|
||||
el.extend(
|
||||
[
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
"conf_path",
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
]
|
||||
)
|
||||
return el
|
||||
|
||||
@property
|
||||
def root_path(self) -> Path:
|
||||
"""
|
||||
@@ -376,13 +413,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def autoconvert_path(self) -> Path:
|
||||
"""
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
"""
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self) -> bool:
|
||||
@@ -405,9 +435,13 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> float:
|
||||
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def conversion_cache_size(self) -> float:
|
||||
return self.disk
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> float:
|
||||
return self.max_vram_cache_size or self.vram
|
||||
@@ -435,9 +469,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
|
||||
|
||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||
"""
|
||||
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -10,57 +10,58 @@ default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
graph = Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
)
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name="t2i",
|
||||
description="Converts text to an image",
|
||||
graph=Graph(
|
||||
nodes={
|
||||
"width": IntegerInvocation(id="width", value=512),
|
||||
"height": IntegerInvocation(id="height", value=512),
|
||||
"seed": IntegerInvocation(id="seed", value=-1),
|
||||
"3": NoiseInvocation(id="3"),
|
||||
"4": CompelInvocation(id="4"),
|
||||
"5": CompelInvocation(id="5"),
|
||||
"6": DenoiseLatentsInvocation(id="6"),
|
||||
"7": LatentsToImageInvocation(id="7"),
|
||||
"8": ImageNSFWBlurInvocation(id="8"),
|
||||
},
|
||||
edges=[
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="width", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="width"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="height", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="height"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="seed", field="value"),
|
||||
destination=EdgeConnection(node_id="3", field="seed"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="3", field="noise"),
|
||||
destination=EdgeConnection(node_id="6", field="noise"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="6", field="latents"),
|
||||
destination=EdgeConnection(node_id="7", field="latents"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
||||
),
|
||||
Edge(
|
||||
source=EdgeConnection(node_id="7", field="image"),
|
||||
destination=EdgeConnection(node_id="8", field="image"),
|
||||
),
|
||||
],
|
||||
),
|
||||
graph=graph,
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||
|
||||
205
invokeai/app/services/download_manager.py
Normal file
205
invokeai/app/services/download_manager.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Model download service.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
|
||||
from invokeai.backend.model_manager.download import ( # noqa F401
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueBase,
|
||||
ModelDownloadQueue,
|
||||
ModelSourceMetadata,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .events import EventServiceBase
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: Optional[bool] = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
:param source: Source of the download - URL, repo_id or local Path
|
||||
:param destdir: Directory to download into.
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
:param start: Immediately start job [True]
|
||||
:param event_handler: Callable that receives a DownloadJobBase and acts on it.
|
||||
:returns job id: The numeric ID of the DownloadJobBase object for this task.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Submit a download job.
|
||||
|
||||
:param job: A DownloadJobBase
|
||||
:param start: Immediately start job [True]
|
||||
|
||||
After execution, `job.id` will be set to a non-negative value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
"""
|
||||
List active DownloadJobBases.
|
||||
|
||||
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""
|
||||
Return the DownloadJobBase corresponding to the string ID.
|
||||
|
||||
:param id: ID of the DownloadJobBase.
|
||||
|
||||
Exceptions:
|
||||
* UnknownJobIDException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_all_jobs(self):
|
||||
"""Enqueue all idle and paused jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_all_jobs(self):
|
||||
"""Pause and dequeue all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active and enquedjobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def join(self):
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
||||
|
||||
class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
_event_bus: Optional["EventServiceBase"] = None
|
||||
_queue: DownloadQueueBase
|
||||
|
||||
def __init__(self, event_bus: Optional["EventServiceBase"] = None, **kwargs):
|
||||
"""
|
||||
Initialize new DownloadQueueService object.
|
||||
|
||||
:param event_bus: EventServiceBase object for reporting progress.
|
||||
:param **kwargs: Any of the arguments taken by invokeai.backend.model_manager.download.DownloadQueue.
|
||||
e.g. `max_parallel_dl`.
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
self._queue = ModelDownloadQueue(**kwargs)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: Optional[bool] = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
event_handlers = event_handlers or []
|
||||
if self._event_bus:
|
||||
event_handlers = [*event_handlers, self._event_bus.emit_model_event]
|
||||
return self._queue.create_download_job(
|
||||
source=source,
|
||||
destdir=destdir,
|
||||
filename=filename,
|
||||
start=start,
|
||||
access_token=access_token,
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: bool = True,
|
||||
):
|
||||
return self._queue.submit_download_job(job, start)
|
||||
|
||||
def list_jobs(self) -> List[DownloadJobBase]: # noqa D102
|
||||
return self._queue.list_jobs()
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJobBase: # noqa D102
|
||||
return self._queue.id_to_job(id)
|
||||
|
||||
def start_all_jobs(self): # noqa D102
|
||||
return self._queue.start_all_jobs()
|
||||
|
||||
def pause_all_jobs(self): # noqa D102
|
||||
return self._queue.pause_all_jobs()
|
||||
|
||||
def cancel_all_jobs(self): # noqa D102
|
||||
return self._queue.cancel_all_jobs()
|
||||
|
||||
def prune_jobs(self): # noqa D102
|
||||
return self._queue.prune_jobs()
|
||||
|
||||
def start_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.start_job(job)
|
||||
|
||||
def pause_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.pause_job(job)
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.cancel_job(job)
|
||||
|
||||
def join(self): # noqa D102
|
||||
return self._queue.join()
|
||||
@@ -3,22 +3,32 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||
from invokeai.app.services.model_record_service import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.model_manager.loader import ModelInfo
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
queue_event: str = "queue_event"
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event."""
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
@@ -26,6 +36,9 @@ class EventServiceBase:
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
@@ -35,11 +48,14 @@ class EventServiceBase:
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
node_id=node.get("id"),
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
@@ -50,15 +66,21 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_complete(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -68,6 +90,9 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
@@ -75,9 +100,12 @@ class EventServiceBase:
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -86,63 +114,83 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
||||
def emit_invocation_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
|
||||
def emit_graph_execution_complete(
|
||||
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
|
||||
) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_key: str,
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_key: str,
|
||||
submodel: SubModelType,
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
@@ -152,14 +200,20 @@ class EventServiceBase:
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
@@ -168,18 +222,94 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
self,
|
||||
session_queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload=dict(
|
||||
queue_id=queue_status.queue_id,
|
||||
queue_item=dict(
|
||||
queue_id=session_queue_item.queue_id,
|
||||
item_id=session_queue_item.item_id,
|
||||
status=session_queue_item.status,
|
||||
batch_id=session_queue_item.batch_id,
|
||||
session_id=session_queue_item.session_id,
|
||||
error=session_queue_item.error,
|
||||
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
),
|
||||
batch_status=batch_status.dict(),
|
||||
queue_status=queue_status.dict(),
|
||||
),
|
||||
)
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload=dict(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload=dict(queue_id=queue_id),
|
||||
)
|
||||
|
||||
def emit_model_event(self, job: DownloadJobBase) -> None:
|
||||
"""Emit event when the status of a download/install job changes."""
|
||||
self.dispatch( # use dispatch() directly here because we are not a session event.
|
||||
event_name="model_event", payload=dict(job=job)
|
||||
)
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import uuid
|
||||
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ..invocations import * # noqa: F401 F403
|
||||
from ..invocations.baseinvocation import (
|
||||
@@ -116,6 +117,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if from_type is int and to_type is float:
|
||||
return True
|
||||
|
||||
# allow int|float -> str, pydantic will cast for us
|
||||
if (from_type is int or from_type is float) and to_type is str:
|
||||
return True
|
||||
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
return False
|
||||
@@ -137,19 +142,43 @@ def are_connections_compatible(
|
||||
return are_connection_types_compatible(from_node_field, to_node_field)
|
||||
|
||||
|
||||
class NodeAlreadyInGraphError(Exception):
|
||||
class NodeAlreadyInGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEdgeError(Exception):
|
||||
class InvalidEdgeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(Exception):
|
||||
class NodeNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeAlreadyExecutedError(Exception):
|
||||
class NodeAlreadyExecutedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DuplicateNodeIdError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeFieldNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class NodeIdMismatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidSubGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class CyclicalGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownGraphValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -227,7 +256,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@@ -307,53 +336,108 @@ class Graph(BaseModel):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Validates the graph."""
|
||||
def validate_self(self) -> None:
|
||||
"""
|
||||
Validates the graph.
|
||||
|
||||
Raises an exception if the graph is invalid:
|
||||
- `DuplicateNodeIdError`
|
||||
- `NodeIdMismatchError`
|
||||
- `InvalidSubGraphError`
|
||||
- `NodeNotFoundError`
|
||||
- `NodeFieldNotFoundError`
|
||||
- `CyclicalGraphError`
|
||||
- `InvalidEdgeError`
|
||||
"""
|
||||
|
||||
# Validate that all node ids are unique
|
||||
node_ids = [n.id for n in self.nodes.values()]
|
||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||
if duplicate_node_ids:
|
||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||
|
||||
# Validate that all node ids match the keys in the nodes dict
|
||||
for k, v in self.nodes.items():
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
|
||||
# Validate all subgraphs
|
||||
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
||||
if not gn.graph.is_valid():
|
||||
return False
|
||||
try:
|
||||
gn.graph.validate_self()
|
||||
except Exception as e:
|
||||
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
|
||||
|
||||
# Validate all edges reference nodes in the graph
|
||||
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
|
||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||
return False
|
||||
# Validate that all edges match nodes and fields in the graph
|
||||
for edge in self.edges:
|
||||
source_node = self.nodes.get(edge.source.node_id, None)
|
||||
if source_node is None:
|
||||
raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
||||
|
||||
destination_node = self.nodes.get(edge.destination.node_id, None)
|
||||
if destination_node is None:
|
||||
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
||||
|
||||
# output fields are not on the node object directly, they are on the output type
|
||||
if edge.source.field not in source_node.get_output_type().__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||
)
|
||||
|
||||
# input fields are on the node
|
||||
if edge.destination.field not in destination_node.__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||
)
|
||||
|
||||
# Validate there are no cycles
|
||||
g = self.nx_graph_flat()
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
raise CyclicalGraphError("Graph contains cycles")
|
||||
|
||||
# Validate all edge connections are valid
|
||||
if not all(
|
||||
(
|
||||
are_connections_compatible(
|
||||
self.get_node(e.source.node_id),
|
||||
e.source.field,
|
||||
self.get_node(e.destination.node_id),
|
||||
e.destination.field,
|
||||
for e in self.edges:
|
||||
if not are_connections_compatible(
|
||||
self.get_node(e.source.node_id),
|
||||
e.source.field,
|
||||
self.get_node(e.destination.node_id),
|
||||
e.destination.field,
|
||||
):
|
||||
raise InvalidEdgeError(
|
||||
f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}"
|
||||
)
|
||||
for e in self.edges
|
||||
)
|
||||
|
||||
# Validate all iterators & collectors
|
||||
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
|
||||
for n in self.nodes.values():
|
||||
if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id):
|
||||
raise InvalidEdgeError(f"Invalid iterator node {n.id}")
|
||||
if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id):
|
||||
raise InvalidEdgeError(f"Invalid collector node {n.id}")
|
||||
|
||||
return None
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""
|
||||
Checks if the graph is valid.
|
||||
|
||||
Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error).
|
||||
"""
|
||||
try:
|
||||
self.validate_self()
|
||||
return True
|
||||
except (
|
||||
DuplicateNodeIdError,
|
||||
NodeIdMismatchError,
|
||||
InvalidSubGraphError,
|
||||
NodeNotFoundError,
|
||||
NodeFieldNotFoundError,
|
||||
CyclicalGraphError,
|
||||
InvalidEdgeError,
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate all iterators
|
||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
||||
if not all(
|
||||
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
|
||||
):
|
||||
return False
|
||||
|
||||
# Validate all collectors
|
||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
||||
if not all(
|
||||
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
|
||||
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
@@ -697,8 +781,7 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
@@ -735,6 +818,12 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
@@ -847,7 +936,7 @@ class GraphExecutionState(BaseModel):
|
||||
new_node = copy.deepcopy(node)
|
||||
|
||||
# Create the node id (use a random uuid)
|
||||
new_node.id = str(uuid.uuid4())
|
||||
new_node.id = uuid_string()
|
||||
|
||||
# Set the iteration index for iteration invocations
|
||||
if isinstance(new_node, IterateInvocation):
|
||||
@@ -1082,7 +1171,7 @@ class ExposedNodeOutput(BaseModel):
|
||||
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
|
||||
@@ -9,6 +9,7 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
@@ -79,6 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[Path, PILImageType]
|
||||
__max_cache_size: int
|
||||
__compress_level: int
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__cache = dict()
|
||||
@@ -87,7 +89,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||
|
||||
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
|
||||
# Validate required output folders at launch
|
||||
self.__validate_storage_folders()
|
||||
|
||||
@@ -134,7 +136,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
if original_workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||
|
||||
@@ -148,24 +148,20 @@ class ImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
self._lock = lock
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
@@ -588,7 +584,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY images.created_at DESC
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
@@ -38,6 +38,29 @@ if TYPE_CHECKING:
|
||||
class ImageServiceABC(ABC):
|
||||
"""High-level service for image management."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||
"""Register a callback for when an image is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an image is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: ImageDTO) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
@@ -161,6 +184,7 @@ class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
super().__init__()
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
@@ -217,6 +241,7 @@ class ImageService(ImageServiceABC):
|
||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
@@ -235,7 +260,9 @@ class ImageService(ImageServiceABC):
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.image_records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
image_dto = self.get_dto(image_name)
|
||||
self._on_changed(image_dto)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
@@ -374,6 +401,7 @@ class ImageService(ImageServiceABC):
|
||||
try:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image record")
|
||||
raise
|
||||
@@ -390,6 +418,8 @@ class ImageService(ImageServiceABC):
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete_many(image_names)
|
||||
for image_name in image_names:
|
||||
self._on_deleted(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
raise
|
||||
@@ -406,6 +436,7 @@ class ImageService(ImageServiceABC):
|
||||
count = len(image_names)
|
||||
for image_name in image_names:
|
||||
self._services.image_files.delete(image_name)
|
||||
self._on_deleted(image_name)
|
||||
return count
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error("Failed to delete image records")
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
|
||||
|
||||
class InvocationCacheBase(ABC):
|
||||
"""
|
||||
Base class for invocation caches.
|
||||
When an invocation is executed, it is hashed and its output stored in the cache.
|
||||
When new invocations are executed, if they are flagged with `use_cache`, they
|
||||
will attempt to pull their value from the cache before executing.
|
||||
|
||||
Implementations should register for the `on_deleted` event of the `images` and `latents`
|
||||
services, and delete any cached outputs that reference the deleted image or latent.
|
||||
|
||||
See the memory implementation for an example.
|
||||
|
||||
Implementations should respect the `node_cache_size` configuration value, and skip all
|
||||
cache logic if the value is set to 0.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
"""Retrieves an invocation output from the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||
"""Stores an invocation output in the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: Union[int, str]) -> None:
|
||||
"""Deletes an invocation output from the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clears the cache"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disable(self) -> None:
|
||||
"""Disables the cache, overriding the max cache size"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enable(self) -> None:
|
||||
"""Enables the cache, letting the the max cache size take effect"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self) -> InvocationCacheStatus:
|
||||
"""Returns the status of the cache"""
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvocationCacheStatus(BaseModel):
|
||||
size: int = Field(description="The current size of the invocation cache")
|
||||
hits: int = Field(description="The number of cache hits")
|
||||
misses: int = Field(description="The number of cache misses")
|
||||
enabled: bool = Field(description="Whether the invocation cache is enabled")
|
||||
max_size: int = Field(description="The maximum size of the invocation cache")
|
||||
@@ -0,0 +1,126 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class CachedItem:
|
||||
invocation_output: BaseInvocationOutput = field(compare=False)
|
||||
invocation_output_json: str = field(compare=False)
|
||||
|
||||
|
||||
class MemoryInvocationCache(InvocationCacheBase):
|
||||
_cache: OrderedDict[Union[int, str], CachedItem]
|
||||
_max_cache_size: int
|
||||
_disabled: bool
|
||||
_hits: int
|
||||
_misses: int
|
||||
_invoker: Invoker
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, max_cache_size: int = 0) -> None:
|
||||
self._cache = OrderedDict()
|
||||
self._max_cache_size = max_cache_size
|
||||
self._disabled = False
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
self._lock = Lock()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._invoker.services.images.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
||||
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0 or self._disabled:
|
||||
return None
|
||||
item = self._cache.get(key, None)
|
||||
if item is not None:
|
||||
self._hits += 1
|
||||
self._cache.move_to_end(key)
|
||||
return item.invocation_output
|
||||
self._misses += 1
|
||||
return None
|
||||
|
||||
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0 or self._disabled or key in self._cache:
|
||||
return
|
||||
# If the cache is full, we need to remove the least used
|
||||
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
number_to_delete = min(number_to_delete, len(self._cache))
|
||||
for _ in range(number_to_delete):
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
def _delete(self, key: Union[int, str]) -> None:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
|
||||
def delete(self, key: Union[int, str]) -> None:
|
||||
with self._lock:
|
||||
return self._delete(key)
|
||||
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._cache.clear()
|
||||
self._misses = 0
|
||||
self._hits = 0
|
||||
|
||||
@staticmethod
|
||||
def create_key(invocation: BaseInvocation) -> int:
|
||||
return hash(invocation.json(exclude={"id"}))
|
||||
|
||||
def disable(self) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._disabled = True
|
||||
|
||||
def enable(self) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._disabled = False
|
||||
|
||||
def get_status(self) -> InvocationCacheStatus:
|
||||
with self._lock:
|
||||
return InvocationCacheStatus(
|
||||
hits=self._hits,
|
||||
misses=self._misses,
|
||||
enabled=not self._disabled and self._max_cache_size > 0,
|
||||
size=len(self._cache),
|
||||
max_size=self._max_cache_size,
|
||||
)
|
||||
|
||||
def _delete_by_match(self, to_match: str) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
keys_to_delete = set()
|
||||
for key, cached_item in self._cache.items():
|
||||
if to_match in cached_item.invocation_output_json:
|
||||
keys_to_delete.add(key)
|
||||
if not keys_to_delete:
|
||||
return
|
||||
for key in keys_to_delete:
|
||||
self._delete(key)
|
||||
self._invoker.services.logger.debug(
|
||||
f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}"
|
||||
)
|
||||
@@ -11,6 +11,13 @@ from pydantic import BaseModel, Field
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||
session_queue_item_id: int = Field(
|
||||
description="The ID of session queue item from which this invocation queue item came"
|
||||
)
|
||||
session_queue_batch_id: str = Field(
|
||||
description="The ID of the session batch from which this invocation queue item came"
|
||||
)
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -9,15 +9,21 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download_manager import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_install_service import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@@ -28,15 +34,21 @@ class InvocationServices:
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||
images: "ImageServiceABC"
|
||||
latents: "LatentsStorageBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_record_store: "ModelRecordServiceBase"
|
||||
model_loader: "ModelLoadServiceBase"
|
||||
model_installer: "ModelInstallServiceBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
session_queue: "SessionQueueBase"
|
||||
session_processor: "SessionProcessorBase"
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,19 +56,24 @@ class InvocationServices:
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||
images: "ImageServiceABC",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_record_store: "ModelRecordServiceBase",
|
||||
model_loader: "ModelLoadServiceBase",
|
||||
model_installer: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
@@ -64,7 +81,13 @@ class InvocationServices:
|
||||
self.images = images
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.model_record_store = model_record_store
|
||||
self.model_loader = model_loader
|
||||
self.model_installer = model_installer
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
|
||||
@@ -38,12 +38,12 @@ import psutil
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .graph import GraphExecutionState
|
||||
from .item_storage import ItemStorageABC
|
||||
from .model_manager_service import ModelManagerService
|
||||
from .model_loader_service import ModelLoadServiceBase
|
||||
|
||||
# size of GIG in bytes
|
||||
GIG = 1073741824
|
||||
@@ -174,13 +174,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
graph_id: str
|
||||
start_time: float
|
||||
ram_used: int
|
||||
model_manager: ModelManagerService
|
||||
model_loader: ModelLoadServiceBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
model_loader: ModelLoadServiceBase,
|
||||
collector: "InvocationStatsServiceBase",
|
||||
):
|
||||
"""Initialize statistics for this run."""
|
||||
@@ -189,15 +189,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self.graph_id = graph_id
|
||||
self.start_time = 0.0
|
||||
self.ram_used = 0
|
||||
self.model_manager = model_manager
|
||||
self.model_loader = model_loader
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.ram_used = psutil.Process().memory_info().rss
|
||||
if self.model_manager:
|
||||
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
if self.model_loader:
|
||||
self.model_loader.collect_cache_stats(self.collector._cache_stats[self.graph_id])
|
||||
|
||||
def __exit__(self, *args):
|
||||
"""Called on exit from the context."""
|
||||
@@ -208,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self.collector.update_invocation_stats(
|
||||
graph_id=self.graph_id,
|
||||
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||
invocation_type=self.invocation.type,
|
||||
time_used=time.time() - self.start_time,
|
||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||
)
|
||||
@@ -217,12 +217,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
model_manager: ModelManagerService,
|
||||
model_loader: ModelLoadServiceBase,
|
||||
) -> StatsContext:
|
||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||
self._stats[graph_execution_state_id] = NodeLog()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||
return self.StatsContext(invocation, graph_execution_state_id, model_loader, self)
|
||||
|
||||
def reset_all_stats(self):
|
||||
"""Zero all statistics"""
|
||||
|
||||
@@ -17,7 +17,14 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
||||
def invoke(
|
||||
self,
|
||||
session_queue_id: str,
|
||||
session_queue_item_id: int,
|
||||
session_queue_batch_id: str,
|
||||
graph_execution_state: GraphExecutionState,
|
||||
invoke_all: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
@@ -32,7 +39,9 @@ class Invoker:
|
||||
# Queue the invocation
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
# session_id = session.id,
|
||||
session_queue_id=session_queue_id,
|
||||
session_queue_item_id=session_queue_item_id,
|
||||
session_queue_batch_id=session_queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
invoke_all=invoke_all,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -11,6 +11,13 @@ import torch
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
@@ -23,6 +30,22 @@ class LatentsStorageBase(ABC):
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: torch.Tensor) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
@@ -33,6 +56,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
@@ -50,11 +74,13 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
192
invokeai/app/services/model_convert.py
Normal file
192
invokeai/app/services/model_convert.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# Copyright 2023 Lincoln Stein and the InvokeAI Team
|
||||
|
||||
"""
|
||||
Convert and merge models.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_install_service import ModelInstallServiceBase
|
||||
from .model_loader_service import ModelInfo, ModelLoadServiceBase
|
||||
from .model_record_service import ModelConfigBase, ModelRecordServiceBase, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelConvertBase(ABC):
|
||||
"""Convert and merge models."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
cls,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version ans well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
pass
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConvert(ModelConvertBase):
|
||||
"""Implementation of ModelConvertBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
self.loader = loader
|
||||
self.installer = installer
|
||||
self.store = store
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version as well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
new_diffusers_path = None
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
try:
|
||||
info: ModelConfigBase = self.store.get_model(key)
|
||||
|
||||
if info.model_format != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {info.name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `path`. It doesn't matter
|
||||
# what submodel type we request here, so we get the smallest.
|
||||
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
|
||||
converted_model: ModelInfo = self.loader.get_model(key, **submodel)
|
||||
|
||||
checkpoint_path = config.models_path / info.path
|
||||
old_diffusers_path = config.models_path / converted_model.location
|
||||
|
||||
# new values to write in
|
||||
update = info.dict()
|
||||
update.pop("config")
|
||||
update["model_format"] = "diffusers"
|
||||
update["path"] = str(converted_model.location)
|
||||
|
||||
if dest_directory:
|
||||
new_diffusers_path = Path(dest_directory) / info.name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
move(old_diffusers_path, new_diffusers_path)
|
||||
update["path"] = new_diffusers_path.as_posix()
|
||||
|
||||
self.store.update_model(key, update)
|
||||
result = self.installer.sync_model_path(key, ignore_hash_change=True)
|
||||
except Exception as excp:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
if new_diffusers_path:
|
||||
rmtree(new_diffusers_path)
|
||||
raise excp
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
merger = ModelMerger(self.store)
|
||||
try:
|
||||
if not merged_model_name:
|
||||
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
|
||||
raise Exception("not implemented")
|
||||
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
653
invokeai/app/services/model_install_service.py
Normal file
653
invokeai/app/services/model_install_service.py
Normal file
@@ -0,0 +1,653 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
import re
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend import get_precision
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.download.model_queue import (
|
||||
HTTP_RE,
|
||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
||||
DownloadJobMetadataURL,
|
||||
DownloadJobRepoID,
|
||||
DownloadJobWithMetadata,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.models import InvalidModelException
|
||||
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .events import EventServiceBase
|
||||
|
||||
from .download_manager import (
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadQueueService,
|
||||
DownloadQueueServiceBase,
|
||||
ModelSourceMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallJob(DownloadJobBase):
|
||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
probe_override: Optional[Dict[str, Any]] = Field(
|
||||
description="Keys in this dict will override like-named attributes in the automatic probe info",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallURLJob(DownloadJobMetadataURL, ModelInstallJob):
|
||||
"""Job for installing URLs."""
|
||||
|
||||
|
||||
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
|
||||
"""Job for installing repo ids."""
|
||||
|
||||
|
||||
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
|
||||
"""Job for installing local paths."""
|
||||
|
||||
|
||||
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
):
|
||||
"""
|
||||
Create ModelInstallService object.
|
||||
|
||||
:param config: Optional InvokeAIAppConfig. If None passed,
|
||||
uses the system-wide default app config.
|
||||
:param download: Optional DownloadQueueServiceBase object. If None passed,
|
||||
a default queue object will be created.
|
||||
:param store: Optional ModelConfigStore. If None passed,
|
||||
defaults to `configs/models.yaml`.
|
||||
:param event_bus: InvokeAI event bus for reporting events to.
|
||||
:param event_handlers: List of event handlers to pass to the queue object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def queue(self) -> DownloadQueueServiceBase:
|
||||
"""Return the download queue used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the storage backend used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dict of attributes that will override probed values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
This involves moving the model from its current location into
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
|
||||
:returns id: The string ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
start: Optional[bool] = True,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Download and install the indicated model.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
thread, and the returned object is a
|
||||
invokeai.backend.model_manager.download.DownloadJobBase
|
||||
object which can be interrogated to get the status of
|
||||
the download and install process. Call our `wait_for_installs()`
|
||||
method to wait for all downloads and installations to complete.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
:param inplace: If True, local paths will not be moved into
|
||||
the models directory, but registered in place (the default).
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
:param subfolder: When downloading HF repo_ids this can be used to
|
||||
specify a subfolder of the HF repository to download from.
|
||||
:param probe_override: Optional dict. Any fields in this dict
|
||||
will override corresponding probe fields. Use it to override
|
||||
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
|
||||
:param metadata: Use this to override the fields 'description`,
|
||||
`author`, `tags`, `source` and `license`.
|
||||
|
||||
:returns ModelInstallJob object.
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||
"""
|
||||
Recursively scan directory for new models and register or install them.
|
||||
|
||||
:param scan_dir: Path to the directory to scan.
|
||||
:param install: Install if True, otherwise register in place.
|
||||
:returns list of IDs: Returns list of IDs of models registered/installed
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def hash(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Compute and return the fast hash of the model.
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:return str: FastHash of the model for use as an ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Model installer class handles installation from a local path."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_logger: Logger
|
||||
_store: ModelConfigStore
|
||||
_download_queue: DownloadQueueServiceBase
|
||||
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
|
||||
_installed: Set[str] = Field(default=set)
|
||||
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
|
||||
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
|
||||
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
|
||||
_event_bus: Optional["EventServiceBase"] = Field(description="an event bus to send install events to", default=None)
|
||||
|
||||
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._app_config = config or InvokeAIAppConfig.get_config()
|
||||
self._store = store or ModelRecordServiceBase.open(self._app_config)
|
||||
self._logger = InvokeAILogger.get_logger(config=self._app_config)
|
||||
self._event_bus = event_bus
|
||||
self._precision = get_precision()
|
||||
self._handlers = event_handlers
|
||||
if self._event_bus:
|
||||
self._handlers.append(self._event_bus.emit_model_event)
|
||||
|
||||
self._download_queue = queue or DownloadQueueService(event_bus=event_bus)
|
||||
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
||||
self._installed = set()
|
||||
self._tmpdir = None
|
||||
|
||||
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
|
||||
"""Call automatically at process start."""
|
||||
self.sync_to_config()
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueServiceBase:
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
return self._app_config
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
start: Optional[bool] = True,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob: # noqa D102
|
||||
queue = self._download_queue
|
||||
variant = variant or ("fp16" if self._precision == "float16" else None)
|
||||
|
||||
job = self._make_download_job(
|
||||
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
|
||||
)
|
||||
handler = (
|
||||
self._complete_registration_handler
|
||||
if inplace and Path(source).exists()
|
||||
else self._complete_installation_handler
|
||||
)
|
||||
if isinstance(job, ModelInstallJob):
|
||||
job.probe_override = probe_override
|
||||
if metadata and isinstance(job, DownloadJobWithMetadata):
|
||||
job.metadata = metadata
|
||||
job.add_event_handler(handler)
|
||||
|
||||
self._async_installs[source] = None
|
||||
queue.submit_download_job(job, start=start)
|
||||
return job
|
||||
|
||||
def register_path(
|
||||
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
return self._register(model_path, info)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
|
||||
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
new_path = self._move_model(model_path, dest_path)
|
||||
new_hash = self.hash(new_path)
|
||||
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
return self._register(
|
||||
new_path,
|
||||
info,
|
||||
)
|
||||
|
||||
def unregister(self, key: str): # noqa D102
|
||||
self._store.del_model(key)
|
||||
|
||||
def delete(self, key: str): # noqa D102
|
||||
model = self._store.get_model(key)
|
||||
path = self._app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def conditionally_delete(self, key: str): # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self._store.get_model(key)
|
||||
models_dir = self._app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
|
||||
key: str = FastModelHash.hash(model_path)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self._app_config.models_path):
|
||||
model_path = model_path.relative_to(self._app_config.models_path)
|
||||
|
||||
registration_data = dict(
|
||||
path=model_path.as_posix(),
|
||||
name=model_path.name if model_path.is_dir() else model_path.stem,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_format=info.format,
|
||||
hash=key,
|
||||
)
|
||||
# add 'main' specific fields
|
||||
if info.model_type == ModelType.Main:
|
||||
if info.variant_type:
|
||||
registration_data.update(variant=info.variant_type)
|
||||
if info.format == ModelFormat.Checkpoint:
|
||||
try:
|
||||
config_file = self._legacy_configs[info.base_type][info.variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
if prediction_type := info.prediction_type:
|
||||
config_file = config_file[prediction_type]
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException(
|
||||
"Configuration file for this checkpoint could not be determined"
|
||||
) from exc
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
while new_path.exists():
|
||||
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||
if not path.exists():
|
||||
new_path = path
|
||||
counter += 1
|
||||
return move(old_path, new_path)
|
||||
|
||||
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
|
||||
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
|
||||
if overrides: # used to override probe fields
|
||||
for key, value in overrides.items():
|
||||
try:
|
||||
setattr(info, key, value) # skip validation errors
|
||||
except Exception:
|
||||
pass
|
||||
return info
|
||||
|
||||
def _complete_installation_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
if isinstance(job, DownloadJobWithMetadata):
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.name = metadata.name or info.name
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
info.thumbnail_url = metadata.thumbnail_url
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = self._download_queue.list_jobs()
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
|
||||
def _complete_registration_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
|
||||
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self._store.get_model(key)
|
||||
old_path = Path(model.path)
|
||||
models_dir = self._app_config.models_path
|
||||
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.hash = self.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.hash != key:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
|
||||
self._store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _make_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> ModelInstallJob:
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
source = source.strip()
|
||||
|
||||
# In the event that we are being asked to install a path that is already on disk,
|
||||
# we simply probe and register/install it. The job does not actually do anything, but we
|
||||
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
|
||||
if Path(source).exists(): # a path that is already on disk
|
||||
destdir = source
|
||||
return ModelInstallPathJob(source=source, destination=Path(destdir), event_handlers=self._handlers)
|
||||
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._app_config.models_path
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
|
||||
cls = ModelInstallJob
|
||||
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||
cls = ModelInstallRepoIDJob
|
||||
source = match.group(1)
|
||||
subfolder = match.group(2) or subfolder
|
||||
kwargs = dict(variant=variant, subfolder=subfolder)
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = ModelInstallURLJob
|
||||
kwargs = {}
|
||||
else:
|
||||
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
|
||||
return cls(
|
||||
source=str(source),
|
||||
destination=Path(self._tmpdir.name),
|
||||
access_token=access_token,
|
||||
priority=priority,
|
||||
event_handlers=self._handlers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""Pause until all installation jobs have completed."""
|
||||
self._download_queue.join()
|
||||
id_map = self._async_installs
|
||||
self._async_installs = dict()
|
||||
return id_map
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._installed = set()
|
||||
search.search(scan_dir)
|
||||
return list(self._installed)
|
||||
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
||||
New models will be added to the storage backend. Missing models
|
||||
will be deleted.
|
||||
"""
|
||||
defunct_models = set()
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk")
|
||||
for model_config in self._store.all_models():
|
||||
path = Path(model_config.path)
|
||||
if not path.exists():
|
||||
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
|
||||
defunct_models.add(model_config.key)
|
||||
for key in defunct_models:
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
self.scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
self.scan_directory(self._app_config.root_path / autoimport)
|
||||
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
return FastModelHash.hash(model_path)
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.register_path(model)
|
||||
self.sync_model_path(id) # possibly move it to right place in `models`
|
||||
self._logger.info(f"Registered {model.name} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _scan_install(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.install_path(model)
|
||||
self._logger.info(f"Installed {model} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
140
invokeai/app/services/model_loader_service.py
Normal file
140
invokeai/app/services/model_loader_service.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.loader import ModelConfigBase, ModelInfo, ModelLoad
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
"""Load models into memory."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
):
|
||||
"""
|
||||
Initialize a ModelLoadService
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: ModelConfigStore object for fetching configuration information
|
||||
installation and download events will be sent to the event bus.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model identified by key.
|
||||
|
||||
:param key: Unique key returned by the ModelConfigStore module.
|
||||
:param submodel_type: Submodel to return (required for main models)
|
||||
:param context" Optional InvocationContext, used in event reporting.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""Reset model cache statistics for graph with graph_id."""
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
_loader: ModelLoad
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
record_store: Union[ModelConfigStore, ModelRecordServiceBase],
|
||||
):
|
||||
"""
|
||||
Initialize a ModelLoadService.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: ModelRecordServiceBase or ModelConfigStore object for fetching configuration information
|
||||
installation and download events will be sent to the event bus.
|
||||
"""
|
||||
self._loader = ModelLoad(config, record_store)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model.
|
||||
|
||||
The submodel is required when fetching a main model.
|
||||
"""
|
||||
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_key=key,
|
||||
submodel=submodel_type,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics. Is this used?
|
||||
"""
|
||||
self._loader.collect_cache_stats(cache_stats)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_key: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
)
|
||||
@@ -1,669 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelManager,
|
||||
ModelMerger,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if config.model_conf_path and config.model_conf_path.exists():
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
|
||||
logger.debug(f"Config file={config_file}")
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
logger.info(f"GPU device = {device} {device_name}")
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `ram_cache_size` config setting
|
||||
max_cache_size = config.ram_cache_size
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
|
||||
self.mgr = ModelManager(
|
||||
config=config_file,
|
||||
device_type=device,
|
||||
precision=dtype,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info("Model manager service initialized")
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name, base_model, model_type)
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {model_name}")
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f"delete model {model_name}")
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
return self.mgr.commit(conf_file)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names=model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels([directory], self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
:param base_model: Current base of the model
|
||||
:param model_type: Model type (can't be changed)
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=new_name,
|
||||
new_base=new_base,
|
||||
)
|
||||
130
invokeai/app/services/model_record_service.py
Normal file
130
invokeai/app/services/model_record_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager import ( # noqa F401
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.storage import ( # noqa F401
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreSQL,
|
||||
ModelConfigStoreYAML,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ModelConfigStore):
|
||||
"""
|
||||
Responsible for managing model configuration records.
|
||||
|
||||
This is an ABC that is simply a subclassing of the ModelConfigStore ABC
|
||||
in the backend.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_db_file(cls, db_file: Path) -> ModelRecordServiceBase:
|
||||
"""
|
||||
Initialize a new object from a database file.
|
||||
|
||||
If the path does not exist, a new sqlite3 db will be initialized.
|
||||
|
||||
:param db_file: Path to the database file
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def open(
|
||||
cls, config: InvokeAIAppConfig, conn: Optional[sqlite3.Connection] = None, lock: Optional[threading.Lock] = None
|
||||
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
|
||||
"""
|
||||
Choose either a ModelConfigStoreSQL or a ModelConfigStoreFile backend.
|
||||
|
||||
Logic is as follows:
|
||||
1. if config.model_config_db contains a Path, then
|
||||
a. if the path looks like a .db file, open a new sqlite3 connection and return a ModelRecordServiceSQL
|
||||
b. if the path looks like a .yaml file, return a new ModelRecordServiceFile
|
||||
c. otherwise bail
|
||||
2. if config.model_config_db is the literal 'auto', then use the passed sqlite3 connection and thread lock.
|
||||
a. if either of these is missing, then we create our own connection to the invokeai.db file, which *should*
|
||||
be a safe thing to do - sqlite3 will use file-level locking.
|
||||
3. if config.model_config_db is None, then fall back to config.conf_path, using a yaml file
|
||||
"""
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = config.model_config_db
|
||||
if db is None:
|
||||
return ModelRecordServiceFile.from_db_file(config.model_conf_path)
|
||||
if str(db) == "auto":
|
||||
logger.info("Model config storage = main InvokeAI database")
|
||||
return (
|
||||
ModelRecordServiceSQL.from_connection(conn, lock)
|
||||
if (conn and lock)
|
||||
else ModelRecordServiceSQL.from_db_file(config.db_path)
|
||||
)
|
||||
assert isinstance(db, Path)
|
||||
suffix = db.suffix
|
||||
if suffix == ".yaml":
|
||||
logger.info(f"Model config storage = {str(db)}")
|
||||
return ModelRecordServiceFile.from_db_file(config.root_path / db)
|
||||
elif suffix == ".db":
|
||||
logger.info(f"Model config storage = {str(db)}")
|
||||
return ModelRecordServiceSQL.from_db_file(config.root_path / db)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unrecognized model config record db file type {db} in "model_config_db" configuration variable.'
|
||||
)
|
||||
|
||||
|
||||
class ModelRecordServiceSQL(ModelConfigStoreSQL):
|
||||
"""
|
||||
ModelRecordService that uses Sqlite for its backend.
|
||||
Please see invokeai/backend/model_manager/storage/sql.py for
|
||||
the implementation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_connection(cls, conn: sqlite3.Connection, lock: threading.Lock) -> ModelRecordServiceSQL:
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
This is the same as the default __init__() constructor.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
return cls(conn, lock)
|
||||
|
||||
@classmethod
|
||||
def from_db_file(cls, db_file: Path) -> ModelRecordServiceSQL: # noqa D102 - docstring in ABC
|
||||
Path(db_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(db_file, check_same_thread=False)
|
||||
lock = threading.Lock()
|
||||
return cls(conn, lock)
|
||||
|
||||
|
||||
class ModelRecordServiceFile(ModelConfigStoreYAML):
|
||||
"""
|
||||
ModelRecordService that uses a YAML file for its backend.
|
||||
|
||||
Please see invokeai/backend/model_manager/storage/yaml.py for
|
||||
the implementation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_db_file(cls, db_file: Path) -> ModelRecordServiceFile: # noqa D102 - docstring in ABC
|
||||
return cls(db_file)
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@@ -37,10 +38,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
queue_item = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||
|
||||
@@ -48,7 +50,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
try:
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
@@ -56,6 +57,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
self.__invoker.services.events.emit_session_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
@@ -67,6 +71,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
node_id=queue_item.invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
@@ -79,6 +86,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -87,15 +97,19 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||
# this accomodates nodes which require a value, but get it only from a
|
||||
# connection
|
||||
model_loader = self.__invoker.services.model_loader
|
||||
with statistics.collect_stats(invocation, graph_id, model_loader):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -111,6 +125,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -138,6 +155,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -155,10 +175,19 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.session_queue_batch_id,
|
||||
session_queue_item_id=queue_item.session_queue_item_id,
|
||||
session_queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state=graph_execution_state,
|
||||
invoke_all=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -166,7 +195,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||
"""Enum for resource types."""
|
||||
@@ -25,6 +26,6 @@ class SimpleNameService(NameServiceBase):
|
||||
|
||||
# TODO: Add customizable naming schemes
|
||||
def create_image_name(self) -> str:
|
||||
uuid_str = str(uuid.uuid4())
|
||||
uuid_str = uuid_string()
|
||||
filename = f"{uuid_str}.png"
|
||||
return filename
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
class SessionProcessorBase(ABC):
|
||||
"""
|
||||
Base class for session processor.
|
||||
|
||||
The session processor is responsible for executing sessions. It runs a simple polling loop,
|
||||
checking the session queue for new sessions to execute. It must coordinate with the
|
||||
invocation queue to ensure only one session is executing at a time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
"""Starts or resumes the session processor"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause(self) -> SessionProcessorStatus:
|
||||
"""Pauses the session processor"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
"""Gets the status of the session processor"""
|
||||
pass
|
||||
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SessionProcessorStatus(BaseModel):
|
||||
is_started: bool = Field(description="Whether the session processor is started")
|
||||
is_processing: bool = Field(description="Whether a session is being processed")
|
||||
@@ -0,0 +1,140 @@
|
||||
import traceback
|
||||
from threading import BoundedSemaphore
|
||||
from threading import Event as ThreadEvent
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
POLLING_INTERVAL = 1
|
||||
THREAD_LIMIT = 1
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker: Invoker = invoker
|
||||
self.__queue_item: Optional[SessionQueueItem] = None
|
||||
|
||||
self.__resume_event = ThreadEvent()
|
||||
self.__stop_event = ThreadEvent()
|
||||
self.__poll_now_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self.__process,
|
||||
kwargs=dict(
|
||||
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
|
||||
),
|
||||
)
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self.__poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name in [
|
||||
"graph_execution_state_complete",
|
||||
"invocation_error",
|
||||
"session_retrieval_error",
|
||||
"invocation_retrieval_error",
|
||||
]:
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
elif (
|
||||
event_name == "session_canceled"
|
||||
and self.__queue_item is not None
|
||||
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
|
||||
):
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
elif event_name == "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
if not self.__resume_event.is_set():
|
||||
self.__resume_event.set()
|
||||
return self.get_status()
|
||||
|
||||
def pause(self) -> SessionProcessorStatus:
|
||||
if self.__resume_event.is_set():
|
||||
self.__resume_event.clear()
|
||||
return self.get_status()
|
||||
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self.__resume_event.is_set(),
|
||||
is_processing=self.__queue_item is not None,
|
||||
)
|
||||
|
||||
def __process(
|
||||
self,
|
||||
stop_event: ThreadEvent,
|
||||
poll_now_event: ThreadEvent,
|
||||
resume_event: ThreadEvent,
|
||||
):
|
||||
try:
|
||||
stop_event.clear()
|
||||
resume_event.set()
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[SessionQueueItem] = None
|
||||
self.__invoker.services.logger
|
||||
while not stop_event.is_set():
|
||||
poll_now_event.clear()
|
||||
try:
|
||||
# do not dequeue if there is already a session running
|
||||
if self.__queue_item is None and resume_event.is_set():
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.batch_id,
|
||||
session_queue_id=queue_item.queue_id,
|
||||
session_queue_item_id=queue_item.item_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
||||
if queue_item is None:
|
||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||
if queue_item is not None:
|
||||
self.__invoker.services.session_queue.cancel_queue_item(
|
||||
queue_item.item_id, error=traceback.format_exc()
|
||||
)
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self.__queue_item = None
|
||||
self.__threadLimit.release()
|
||||
112
invokeai/app/services/session_queue/session_queue_base.py
Normal file
112
invokeai/app/services/session_queue/session_queue_base.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
PruneResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
|
||||
class SessionQueueBase(ABC):
|
||||
"""Base class for session queue"""
|
||||
|
||||
@abstractmethod
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
"""Dequeues the next session queue item."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||
"""Enqueues a single graph for execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
"""Enqueues all permutations of a batch for execution."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently-executing session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next session queue item (does not dequeue it)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
"""Deletes all session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
"""Deletes all completed and errored session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
"""Checks if the queue is empty"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
"""Checks if the queue is empty"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
"""Gets the status of the queue"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
"""Gets the status of a batch"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
|
||||
"""Cancels a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
"""Gets a page of session queue items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
pass
|
||||
423
invokeai/app/services/session_queue/session_queue_common.py
Normal file
423
invokeai/app/services/session_queue/session_queue_common.py
Normal file
@@ -0,0 +1,423 @@
|
||||
import datetime
|
||||
import json
|
||||
from itertools import chain, product
|
||||
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator, validator
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# region Errors
|
||||
|
||||
|
||||
class BatchZippedLengthError(ValueError):
|
||||
"""Raise when a batch has items of different lengths."""
|
||||
|
||||
|
||||
class BatchItemsTypeError(TypeError):
|
||||
"""Raise when a batch has items of different types."""
|
||||
|
||||
|
||||
class BatchDuplicateNodeFieldError(ValueError):
|
||||
"""Raise when a batch has duplicate node_path and field_name."""
|
||||
|
||||
|
||||
class TooManySessionsError(ValueError):
|
||||
"""Raise when too many sessions are requested."""
|
||||
|
||||
|
||||
class SessionQueueItemNotFoundError(ValueError):
|
||||
"""Raise when a queue item is not found."""
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Batch
|
||||
|
||||
BatchDataType = Union[
|
||||
StrictStr,
|
||||
float,
|
||||
int,
|
||||
]
|
||||
|
||||
|
||||
class NodeFieldValue(BaseModel):
|
||||
node_path: str = Field(description="The node into which this batch data item will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data item will be substituted.")
|
||||
value: BatchDataType = Field(description="The value to substitute into the node/field.")
|
||||
|
||||
|
||||
class BatchDatum(BaseModel):
|
||||
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||
items: list[BatchDataType] = Field(
|
||||
default_factory=list, description="The list of items to substitute into the node/field."
|
||||
)
|
||||
|
||||
|
||||
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
runs: int = Field(
|
||||
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
|
||||
)
|
||||
|
||||
@validator("data")
|
||||
def validate_lengths(cls, v: Optional[BatchDataCollection]):
|
||||
if v is None:
|
||||
return v
|
||||
for batch_data_list in v:
|
||||
first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
|
||||
for i in batch_data_list:
|
||||
if len(i.items) != first_item_length:
|
||||
raise BatchZippedLengthError("Zipped batch items must all have the same length")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_types(cls, v: Optional[BatchDataCollection]):
|
||||
if v is None:
|
||||
return v
|
||||
for batch_data_list in v:
|
||||
for datum in batch_data_list:
|
||||
# Get the type of the first item in the list
|
||||
first_item_type = type(datum.items[0]) if datum.items else None
|
||||
for item in datum.items:
|
||||
if type(item) is not first_item_type:
|
||||
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
|
||||
if v is None:
|
||||
return v
|
||||
paths: set[tuple[str, str]] = set()
|
||||
for batch_data_list in v:
|
||||
for datum in batch_data_list:
|
||||
pair = (datum.node_path, datum.field_name)
|
||||
if pair in paths:
|
||||
raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
|
||||
paths.add(pair)
|
||||
return v
|
||||
|
||||
@root_validator(skip_on_failure=True)
|
||||
def validate_batch_nodes_and_edges(cls, values):
|
||||
batch_data_collection = cast(Optional[BatchDataCollection], values["data"])
|
||||
if batch_data_collection is None:
|
||||
return values
|
||||
graph = cast(Graph, values["graph"])
|
||||
for batch_data_list in batch_data_collection:
|
||||
for batch_data in batch_data_list:
|
||||
try:
|
||||
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
||||
except NodeNotFoundError:
|
||||
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
||||
if batch_data.field_name not in node.__fields__:
|
||||
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||
return values
|
||||
|
||||
@validator("graph")
|
||||
def validate_graph(cls, v: Graph):
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"graph",
|
||||
"runs",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# endregion Batch
|
||||
|
||||
|
||||
# region Queue Items
|
||||
|
||||
DEFAULT_QUEUE_ID = "default"
|
||||
|
||||
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
|
||||
|
||||
|
||||
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
|
||||
field_values_raw = queue_item_dict.get("field_values", None)
|
||||
return parse_raw_as(list[NodeFieldValue], field_values_raw) if field_values_raw is not None else None
|
||||
|
||||
|
||||
def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||
session_raw = queue_item_dict.get("session", "{}")
|
||||
return parse_raw_as(GraphExecutionState, session_raw)
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
item_id: int = Field(description="The identifier of the session queue item")
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
||||
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
||||
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
|
||||
queue_id: str = Field(description="The id of the queue with which this item is associated")
|
||||
field_values: Optional[list[NodeFieldValue]] = Field(
|
||||
default=None, description="The field values that were used for this queue item"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
return SessionQueueItemDTO(**queue_item_dict)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
"queue_id",
|
||||
"session_id",
|
||||
"priority",
|
||||
"session_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
pass
|
||||
|
||||
|
||||
class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
queue_item_dict["session"] = get_session(queue_item_dict)
|
||||
return SessionQueueItem(**queue_item_dict)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"item_id",
|
||||
"status",
|
||||
"batch_id",
|
||||
"queue_id",
|
||||
"session_id",
|
||||
"session",
|
||||
"priority",
|
||||
"session_id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# endregion Queue Items
|
||||
|
||||
# region Query Results
|
||||
|
||||
|
||||
class SessionQueueStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
item_id: Optional[int] = Field(description="The current queue item id")
|
||||
batch_id: Optional[str] = Field(description="The current queue item's batch id")
|
||||
session_id: Optional[str] = Field(description="The current queue item's session id")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
|
||||
|
||||
class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
|
||||
|
||||
class EnqueueBatchResult(BaseModel):
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
priority: int = Field(description="The priority of the enqueued batch")
|
||||
|
||||
|
||||
class EnqueueGraphResult(BaseModel):
|
||||
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
priority: int = Field(description="The priority of the enqueued batch")
|
||||
queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")
|
||||
|
||||
|
||||
class ClearResult(BaseModel):
|
||||
"""Result of clearing the session queue"""
|
||||
|
||||
deleted: int = Field(..., description="Number of queue items deleted")
|
||||
|
||||
|
||||
class PruneResult(ClearResult):
|
||||
"""Result of pruning the session queue"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CancelByBatchIDsResult(BaseModel):
|
||||
"""Result of canceling by list of batch ids"""
|
||||
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IsEmptyResult(BaseModel):
|
||||
"""Result of checking if the session queue is empty"""
|
||||
|
||||
is_empty: bool = Field(..., description="Whether the session queue is empty")
|
||||
|
||||
|
||||
class IsFullResult(BaseModel):
|
||||
"""Result of checking if the session queue is full"""
|
||||
|
||||
is_full: bool = Field(..., description="Whether the session queue is full")
|
||||
|
||||
|
||||
# endregion Query Results
|
||||
|
||||
|
||||
# region Util
|
||||
|
||||
|
||||
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
|
||||
"""
|
||||
Populates the given graph with the given batch data items.
|
||||
"""
|
||||
graph_clone = graph.copy(deep=True)
|
||||
for item in node_field_values:
|
||||
node = graph_clone.get_node(item.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
setattr(node, item.field_name, item.value)
|
||||
graph_clone.update_node(item.node_path, node)
|
||||
return graph_clone
|
||||
|
||||
|
||||
def create_session_nfv_tuples(
|
||||
batch: Batch, maximum: int
|
||||
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
|
||||
"""
|
||||
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||
that was applied to the graph.
|
||||
"""
|
||||
|
||||
# TODO: Should this be a class method on Batch?
|
||||
|
||||
data: list[list[tuple[NodeFieldValue]]] = []
|
||||
batch_data_collection = batch.data if batch.data is not None else []
|
||||
for batch_datum_list in batch_data_collection:
|
||||
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
||||
|
||||
node_field_values_to_zip: list[list[NodeFieldValue]] = []
|
||||
for batch_datum in batch_datum_list:
|
||||
node_field_values = [
|
||||
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
|
||||
for item in batch_datum.items
|
||||
]
|
||||
node_field_values_to_zip.append(node_field_values)
|
||||
data.append(list(zip(*node_field_values_to_zip)))
|
||||
|
||||
# create generator to yield session,nfv tuples
|
||||
count = 0
|
||||
for _ in range(batch.runs):
|
||||
for d in product(*data):
|
||||
if count >= maximum:
|
||||
return
|
||||
flat_node_field_values = list(chain.from_iterable(d))
|
||||
graph = populate_graph(batch.graph, flat_node_field_values)
|
||||
yield (GraphExecutionState(graph=graph), flat_node_field_values)
|
||||
count += 1
|
||||
|
||||
|
||||
def calc_session_count(batch: Batch) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring
|
||||
the overhead of actually generating them. Adapted from `create_sessions().
|
||||
"""
|
||||
# TODO: Should this be a class method on Batch?
|
||||
if not batch.data:
|
||||
return batch.runs
|
||||
data = []
|
||||
for batch_datum_list in batch.data:
|
||||
to_zip = []
|
||||
for batch_datum in batch_datum_list:
|
||||
batch_data_items = range(len(batch_datum.items))
|
||||
to_zip.append(batch_data_items)
|
||||
data.append(list(zip(*to_zip)))
|
||||
data_product = list(product(*data))
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
||||
class SessionQueueValueToInsert(NamedTuple):
|
||||
"""A tuple of values to insert into the session_queue table"""
|
||||
|
||||
queue_id: str # queue_id
|
||||
session: str # session json
|
||||
session_id: str # session_id
|
||||
batch_id: str # batch_id
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
|
||||
|
||||
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
|
||||
values_to_insert: ValuesToInsert = []
|
||||
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
# sessions must have unique id
|
||||
session.id = uuid_string()
|
||||
values_to_insert.append(
|
||||
SessionQueueValueToInsert(
|
||||
queue_id, # queue_id
|
||||
session.json(), # session (json)
|
||||
session.id, # session_id
|
||||
batch.batch_id, # batch_id
|
||||
# must use pydantic_encoder bc field_values is a list of models
|
||||
json.dumps(field_values, default=pydantic_encoder) if field_values else None, # field_values (json)
|
||||
priority, # priority
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
|
||||
# endregion Util
|
||||
835
invokeai/app/services/session_queue/session_queue_sqlite.py
Normal file
835
invokeai/app/services/session_queue/session_queue_sqlite.py
Normal file
@@ -0,0 +1,835 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
DEFAULT_QUEUE_ID,
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
PruneResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
|
||||
class SqliteSessionQueue(SessionQueueBase):
|
||||
__invoker: Invoker
|
||||
__conn: sqlite3.Connection
|
||||
__cursor: sqlite3.Cursor
|
||||
__lock: threading.Lock
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self.__conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self.__conn.row_factory = sqlite3.Row
|
||||
self.__cursor = self.__conn.cursor()
|
||||
self.__lock = lock
|
||||
self._create_tables()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name == "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
|
||||
await self._handle_error_event(event)
|
||||
elif event_name == "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the session queue tables, indicies, and triggers"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS session_queue (
|
||||
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
|
||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||
session TEXT NOT NULL, -- the session to be executed
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||
error TEXT, -- any errors associated with this queue item
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||
started_at DATETIME, -- updated via trigger
|
||||
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'completed'
|
||||
OR NEW.status = 'failed'
|
||||
OR NEW.status = 'canceled'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'in_progress'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||
AFTER UPDATE
|
||||
ON session_queue FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = old.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
WHERE status = 'in_progress';
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, self.__cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
|
||||
|
||||
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND batch_id = ?
|
||||
""",
|
||||
(queue_id, enqueue_result.batch.batch_id),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
|
||||
return EnqueueGraphResult(
|
||||
**enqueue_result.dict(),
|
||||
queue_item=SessionQueueItemDTO.from_dict(dict(result)),
|
||||
)
|
||||
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = calc_session_count(batch)
|
||||
values_to_insert = prepare_values_to_insert(
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
if requested_count > enqueued_count:
|
||||
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
enqueued=enqueued_count,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.from_dict(dict(result))
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def _set_queue_item_status(
|
||||
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?, error = ?
|
||||
WHERE item_id = ?
|
||||
""",
|
||||
(status, error, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id=item_id)
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return queue_item
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
try:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND (
|
||||
status = 'completed'
|
||||
OR status = 'failed'
|
||||
OR status = 'canceled'
|
||||
)
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
status = "failed" if error is not None else "canceled"
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)
|
||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND batch_id IN ({placeholders})
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id is ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
try:
|
||||
item_id = cursor
|
||||
self.__lock.acquire()
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
self.__cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
return SessionQueueStatus(
|
||||
queue_id=queue_id,
|
||||
item_id=current_item.item_id if current_item else None,
|
||||
session_id=current_item.session_id if current_item else None,
|
||||
batch_id=current_item.batch_id if current_item else None,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
14
invokeai/app/services/shared/models.py
Normal file
14
invokeai/app/services/shared/models.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||
|
||||
|
||||
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
||||
"""Cursor-paginated results"""
|
||||
|
||||
limit: int = Field(..., description="Limit of items to get")
|
||||
has_more: bool = Field(..., description="Whether there are more items available")
|
||||
items: list[GenericBaseModel] = Field(..., description="Items")
|
||||
@@ -1,5 +1,5 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
import threading
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
@@ -12,23 +12,19 @@ sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: Lock
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._lock = lock
|
||||
self._conn = conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
@@ -49,8 +45,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
||||
3
invokeai/app/services/thread.py
Normal file
3
invokeai/app/services/thread.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import threading
|
||||
|
||||
lock = threading.Lock()
|
||||
@@ -265,22 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
||||
|
||||
|
||||
def prepare_control_image(
|
||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||
image: Image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
|
||||
width=512, # should be 8 * latent.shape[3]
|
||||
height=512, # should be 8 * latent height[2]
|
||||
# batch_size=1, # currently no batching
|
||||
# num_images_per_prompt=1, # currently only single image
|
||||
width: int,
|
||||
height: int,
|
||||
num_channels: int = 3,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
do_classifier_free_guidance=True,
|
||||
control_mode="balanced",
|
||||
resize_mode="just_resize_simple",
|
||||
):
|
||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||
"""Pre-process images for ControlNets or T2I-Adapters.
|
||||
|
||||
Args:
|
||||
image (Image): The PIL image to pre-process.
|
||||
width (int): The target width in pixels.
|
||||
height (int): The target height in pixels.
|
||||
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
|
||||
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
|
||||
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
|
||||
device (str, optional): The target device for the output image. Defaults to "cuda".
|
||||
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
|
||||
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
|
||||
Defaults to True.
|
||||
control_mode (str, optional): Defaults to "balanced".
|
||||
resize_mode (str, optional): Defaults to "just_resize_simple".
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If resize_mode == "crop_resize_simple".
|
||||
NotImplementedError: If resize_mode == "fill_resize_simple".
|
||||
ValueError: If `resize_mode` is not recognized.
|
||||
ValueError: If `num_channels` is out of range.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The pre-processed input tensor.
|
||||
"""
|
||||
if (
|
||||
resize_mode == "just_resize_simple"
|
||||
or resize_mode == "crop_resize_simple"
|
||||
@@ -289,10 +308,10 @@ def prepare_control_image(
|
||||
image = image.convert("RGB")
|
||||
if resize_mode == "just_resize_simple":
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
elif resize_mode == "crop_resize_simple": # not yet implemented
|
||||
pass
|
||||
elif resize_mode == "fill_resize_simple": # not yet implemented
|
||||
pass
|
||||
elif resize_mode == "crop_resize_simple":
|
||||
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
|
||||
elif resize_mode == "fill_resize_simple":
|
||||
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
|
||||
nimage = np.array(image)
|
||||
nimage = nimage[None, :]
|
||||
nimage = np.concatenate([nimage], axis=0)
|
||||
@@ -313,9 +332,11 @@ def prepare_control_image(
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
pass
|
||||
print("ERROR: invalid resize_mode ==> ", resize_mode)
|
||||
exit(1)
|
||||
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")
|
||||
|
||||
if timage.shape[1] < num_channels or num_channels <= 0:
|
||||
raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
|
||||
timage = timage[:, :num_channels, :, :]
|
||||
|
||||
timage = timage.to(device=device, dtype=dtype)
|
||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -21,3 +22,8 @@ SEED_MAX = np.iinfo(np.uint32).max
|
||||
def get_random_seed():
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
||||
|
||||
def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
@@ -4,7 +4,7 @@ from PIL import Image
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
@@ -110,6 +110,9 @@ def stable_diffusion_step_callback(
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
|
||||
BIN
invokeai/assets/fonts/inter/Inter-Regular.ttf
Executable file
BIN
invokeai/assets/fonts/inter/Inter-Regular.ttf
Executable file
Binary file not shown.
94
invokeai/assets/fonts/inter/LICENSE.txt
Executable file
94
invokeai/assets/fonts/inter/LICENSE.txt
Executable file
@@ -0,0 +1,94 @@
|
||||
Copyright (c) 2016-2020 The Inter Project Authors.
|
||||
"Inter" is trademark of Rasmus Andersson.
|
||||
https://github.com/rsms/inter
|
||||
|
||||
This Font Software is licensed under the SIL Open Font License, Version 1.1.
|
||||
This license is copied below, and is also available with a FAQ at:
|
||||
http://scripts.sil.org/OFL
|
||||
|
||||
-----------------------------------------------------------
|
||||
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
|
||||
-----------------------------------------------------------
|
||||
|
||||
PREAMBLE
|
||||
The goals of the Open Font License (OFL) are to stimulate worldwide
|
||||
development of collaborative font projects, to support the font creation
|
||||
efforts of academic and linguistic communities, and to provide a free and
|
||||
open framework in which fonts may be shared and improved in partnership
|
||||
with others.
|
||||
|
||||
The OFL allows the licensed fonts to be used, studied, modified and
|
||||
redistributed freely as long as they are not sold by themselves. The
|
||||
fonts, including any derivative works, can be bundled, embedded,
|
||||
redistributed and/or sold with any software provided that any reserved
|
||||
names are not used by derivative works. The fonts and derivatives,
|
||||
however, cannot be released under any other type of license. The
|
||||
requirement for fonts to remain under this license does not apply
|
||||
to any document created using the fonts or their derivatives.
|
||||
|
||||
DEFINITIONS
|
||||
"Font Software" refers to the set of files released by the Copyright
|
||||
Holder(s) under this license and clearly marked as such. This may
|
||||
include source files, build scripts and documentation.
|
||||
|
||||
"Reserved Font Name" refers to any names specified as such after the
|
||||
copyright statement(s).
|
||||
|
||||
"Original Version" refers to the collection of Font Software components as
|
||||
distributed by the Copyright Holder(s).
|
||||
|
||||
"Modified Version" refers to any derivative made by adding to, deleting,
|
||||
or substituting -- in part or in whole -- any of the components of the
|
||||
Original Version, by changing formats or by porting the Font Software to a
|
||||
new environment.
|
||||
|
||||
"Author" refers to any designer, engineer, programmer, technical
|
||||
writer or other person who contributed to the Font Software.
|
||||
|
||||
PERMISSION AND CONDITIONS
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of the Font Software, to use, study, copy, merge, embed, modify,
|
||||
redistribute, and sell modified and unmodified copies of the Font
|
||||
Software, subject to the following conditions:
|
||||
|
||||
1) Neither the Font Software nor any of its individual components,
|
||||
in Original or Modified Versions, may be sold by itself.
|
||||
|
||||
2) Original or Modified Versions of the Font Software may be bundled,
|
||||
redistributed and/or sold with any software, provided that each copy
|
||||
contains the above copyright notice and this license. These can be
|
||||
included either as stand-alone text files, human-readable headers or
|
||||
in the appropriate machine-readable metadata fields within text or
|
||||
binary files as long as those fields can be easily viewed by the user.
|
||||
|
||||
3) No Modified Version of the Font Software may use the Reserved Font
|
||||
Name(s) unless explicit written permission is granted by the corresponding
|
||||
Copyright Holder. This restriction only applies to the primary font name as
|
||||
presented to the users.
|
||||
|
||||
4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
|
||||
Software shall not be used to promote, endorse or advertise any
|
||||
Modified Version, except to acknowledge the contribution(s) of the
|
||||
Copyright Holder(s) and the Author(s) or with their explicit written
|
||||
permission.
|
||||
|
||||
5) The Font Software, modified or unmodified, in part or in whole,
|
||||
must be distributed entirely under this license, and must not be
|
||||
distributed under any other license. The requirement for fonts to
|
||||
remain under this license does not apply to any document created
|
||||
using the Font Software.
|
||||
|
||||
TERMINATION
|
||||
This license becomes null and void if any of the above conditions are
|
||||
not met.
|
||||
|
||||
DISCLAIMER
|
||||
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
|
||||
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
|
||||
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
|
||||
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
|
||||
OTHER DEALINGS IN THE FONT SOFTWARE.
|
||||
@@ -1,5 +1,15 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
||||
from .model_management.models import SilenceWarnings # noqa: F401
|
||||
from .model_manager import ( # noqa F401
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelConfigStore,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
SubModelType,
|
||||
)
|
||||
from .util.devices import get_precision # noqa F401
|
||||
|
||||
46
invokeai/backend/image_util/invoke_metadata.py
Normal file
46
invokeai/backend/image_util/invoke_metadata.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""Very simple functions to fetch and print metadata from InvokeAI-generated images."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_invokeai_metadata(image_path: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieve "invokeai_metadata" field from png image.
|
||||
|
||||
:param image_path: Path to the image to read metadata from.
|
||||
May raise:
|
||||
OSError -- image path not found
|
||||
KeyError -- image doesn't contain the metadata field
|
||||
"""
|
||||
image: Image = Image.open(image_path)
|
||||
return json.loads(image.text["invokeai_metadata"])
|
||||
|
||||
|
||||
def print_invokeai_metadata(image_path: Path):
|
||||
"""Pretty-print the metadata."""
|
||||
try:
|
||||
metadata = get_invokeai_metadata(image_path)
|
||||
print(f"{image_path}:\n{json.dumps(metadata, sort_keys=True, indent=4)}")
|
||||
except OSError:
|
||||
print(f"{image_path}:\nNo file found.")
|
||||
except KeyError:
|
||||
print(f"{image_path}:\nNo metadata found.")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the command-line utility."""
|
||||
image_paths = sys.argv[1:]
|
||||
if not image_paths:
|
||||
print(f"Usage: {Path(sys.argv[0]).name} image1 image2 image3 ...")
|
||||
print("\nPretty-print InvokeAI image metadata from the listed png files.")
|
||||
sys.exit(-1)
|
||||
for img in image_paths:
|
||||
print_invokeai_metadata(img)
|
||||
@@ -8,7 +8,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
def check_invokeai_root(config: InvokeAIAppConfig):
|
||||
try:
|
||||
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
||||
assert config.model_conf_path.parent.exists(), f"{config.model_conf_path.parent} not found"
|
||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||
if not config.ignore_missing_core_models:
|
||||
|
||||
196
invokeai/backend/install/install_helper.py
Normal file
196
invokeai/backend/install/install_helper.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob, ModelInstallService, ModelSourceMetadata
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
name: Optional[str] = None
|
||||
base_model: Optional[BaseModelType] = None
|
||||
model_type: Optional[ModelType] = None
|
||||
source: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
recommended: bool = False
|
||||
installed: bool = False
|
||||
default: bool = False
|
||||
requires: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
|
||||
remove_models: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TqdmProgress(object):
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self):
|
||||
self._bars = dict()
|
||||
self._last = dict()
|
||||
|
||||
def job_update(self, job: ModelInstallJob):
|
||||
if not isinstance(job, DownloadJobRemoteSource):
|
||||
return
|
||||
job_id = job.id
|
||||
if job.status == "running" and job.total_bytes > 0: # job starts running before total bytes known
|
||||
if job_id not in self._bars:
|
||||
dest = Path(job.destination).name
|
||||
self._bars[job_id] = tqdm(
|
||||
desc=dest,
|
||||
initial=0,
|
||||
total=job.total_bytes,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
)
|
||||
self._last[job_id] = 0
|
||||
self._bars[job_id].update(job.bytes - self._last[job_id])
|
||||
self._last[job_id] = job.bytes
|
||||
|
||||
|
||||
class InstallHelper(object):
|
||||
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||
|
||||
all_models: Dict[str, UnifiedModelInfo] = dict()
|
||||
_installer: ModelInstallService
|
||||
_config: InvokeAIAppConfig
|
||||
_installed_models: List[str] = []
|
||||
_starter_models: List[str] = []
|
||||
_default_model: Optional[str] = None
|
||||
_initial_models: omegaconf.DictConfig
|
||||
|
||||
def __init__(self, config: InvokeAIAppConfig):
|
||||
self._config = config
|
||||
self._installer = ModelInstallService(config=config, event_handlers=[TqdmProgress().job_update])
|
||||
self._initial_models = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
|
||||
self._initialize_model_lists()
|
||||
|
||||
@property
|
||||
def installer(self) -> ModelInstallService:
|
||||
return self._installer
|
||||
|
||||
def _initialize_model_lists(self):
|
||||
"""
|
||||
Initialize our model slots.
|
||||
|
||||
Set up the following:
|
||||
installed_models -- list of installed model keys
|
||||
starter_models -- list of starter model keys from INITIAL_MODELS
|
||||
all_models -- dict of key => UnifiedModelInfo
|
||||
default_model -- key to default model
|
||||
"""
|
||||
# previously-installed models
|
||||
for model in self._installer.store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
|
||||
self.all_models[key] = info
|
||||
self._installed_models.append(key)
|
||||
|
||||
for key in self._initial_models.keys():
|
||||
if key in self.all_models:
|
||||
# we want to preserve the description
|
||||
description = self.all_models[key].description or self._initial_models[key].get("description")
|
||||
self.all_models[key].description = description
|
||||
else:
|
||||
base_model, model_type, model_name = key.split("/")
|
||||
info = UnifiedModelInfo(
|
||||
name=model_name,
|
||||
model_type=model_type,
|
||||
base_model=base_model,
|
||||
source=self._initial_models[key].source,
|
||||
description=self._initial_models[key].get("description"),
|
||||
recommended=self._initial_models[key].get("recommended", False),
|
||||
default=self._initial_models[key].get("default", False),
|
||||
subfolder=self._initial_models[key].get("subfolder"),
|
||||
requires=list(self._initial_models[key].get("requires", [])),
|
||||
)
|
||||
self.all_models[key] = info
|
||||
if not self.default_model:
|
||||
self._default_model = key
|
||||
elif self._initial_models[key].get("default", False):
|
||||
self._default_model = key
|
||||
self._starter_models.append(key)
|
||||
|
||||
# previously-installed models
|
||||
for model in self._installer.store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
|
||||
self.all_models[key] = info
|
||||
self._installed_models.append(key)
|
||||
|
||||
def recommended_models(self) -> List[UnifiedModelInfo]:
|
||||
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
|
||||
|
||||
def installed_models(self) -> List[UnifiedModelInfo]:
|
||||
return [self._to_model(x) for x in self._installed_models]
|
||||
|
||||
def starter_models(self) -> List[UnifiedModelInfo]:
|
||||
return [self._to_model(x) for x in self._starter_models]
|
||||
|
||||
def default_model(self) -> UnifiedModelInfo:
|
||||
return self._to_model(self._default_model)
|
||||
|
||||
def _to_model(self, key: str) -> UnifiedModelInfo:
|
||||
return self.all_models[key]
|
||||
|
||||
def _add_required_models(self, model_list: List[UnifiedModelInfo]):
|
||||
installed = {x.source for x in self.installed_models()}
|
||||
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||
additional_models = []
|
||||
for model_info in model_list:
|
||||
for requirement in model_info.requires:
|
||||
if requirement not in installed:
|
||||
additional_models.append(reverse_source.get(requirement))
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def add_or_delete(self, selections: InstallSelections):
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
metadata = ModelSourceMetadata(description=model.description, name=model.name)
|
||||
installer.install_model(
|
||||
model.source,
|
||||
subfolder=model.subfolder,
|
||||
access_token=HfFolder.get_token(),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
for model in selections.remove_models:
|
||||
parts = model.split("/")
|
||||
if len(parts) == 1:
|
||||
base_model, model_type, model_name = (None, None, model)
|
||||
else:
|
||||
base_model, model_type, model_name = parts
|
||||
matches = installer.store.search_by_name(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name
|
||||
)
|
||||
if len(matches) > 1:
|
||||
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
||||
elif not matches:
|
||||
print(f"{model}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
print(f"Deleting {m.model_type}:{m.name}")
|
||||
installer.conditionally_delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
||||
@@ -22,7 +22,6 @@ from typing import Any, get_args, get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
import psutil
|
||||
import torch
|
||||
import transformers
|
||||
@@ -38,21 +37,25 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.storage import ConfigFileVersionMismatchException, migrate_models_store
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.model_install import addModelsForm
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
CenteredButtonPress,
|
||||
CheckboxWithChanged,
|
||||
CyclingForm,
|
||||
FileBox,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumnsSimple,
|
||||
SingleSelectWithChanged,
|
||||
WindowTooSmallException,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
@@ -70,7 +73,6 @@ def get_literal_fields(field) -> list[Any]:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
@@ -83,7 +85,6 @@ GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
|
||||
@@ -93,10 +94,12 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
"""
|
||||
|
||||
logger = InvokeAILogger.getLogger()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
"""Dummy widget values."""
|
||||
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
@@ -180,6 +183,22 @@ class ProgressBar:
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage()
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
@@ -458,7 +477,26 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.disk = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.disk, range=(0, 100), step=0.5),
|
||||
out_of=100,
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
@@ -496,6 +534,45 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Location of the database used to store model path and configuration information:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely += 1
|
||||
if first_time:
|
||||
old_opts.model_config_db = "auto"
|
||||
self.model_conf_auto = self.add_widget_intelligent(
|
||||
CheckboxWithChanged,
|
||||
value=str(old_opts.model_config_db) == "auto",
|
||||
name="Main database",
|
||||
relx=2,
|
||||
max_width=25,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
config_db = str(old_opts.model_config_db or old_opts.conf_path)
|
||||
self.model_conf_override = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
value=str(old_opts.root_path / config_db)
|
||||
if config_db != "auto"
|
||||
else str(old_opts.root_path / old_opts.conf_path),
|
||||
name="Specify models config database manually",
|
||||
select_dir=False,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
# begin_entry_at=40,
|
||||
relx=30,
|
||||
max_height=3,
|
||||
max_width=100,
|
||||
scroll_exit=True,
|
||||
hidden=str(old_opts.model_config_db) == "auto",
|
||||
)
|
||||
self.model_conf_auto.on_changed = self.show_hide_model_conf_override
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@@ -507,19 +584,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -556,6 +635,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def show_hide_model_conf_override(self, value):
|
||||
self.model_conf_override.hidden = value
|
||||
self.model_conf_override.display()
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
@@ -591,17 +674,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"disk",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
if not self.autoimport_dirs[attr].value:
|
||||
continue
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
if directory.is_relative_to(config.root_path):
|
||||
directory = directory.relative_to(config.root_path)
|
||||
setattr(new_opts, attr, directory)
|
||||
|
||||
new_opts.model_config_db = "auto" if self.model_conf_auto.value else self.model_conf_override.value
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
@@ -616,13 +703,14 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = program_opts
|
||||
self.invokeai_opts = invokeai_opts
|
||||
self.user_cancelled = False
|
||||
self.autoload_pending = True
|
||||
self.install_selections = default_user_selections(program_opts)
|
||||
self.install_helper = install_helper
|
||||
self.install_selections = default_user_selections(program_opts, install_helper)
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@@ -645,32 +733,28 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
return self.options.marshall_arguments()
|
||||
|
||||
|
||||
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
|
||||
# So we adjust everthing down a bit.
|
||||
return (
|
||||
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts.ram = default_ramcache()
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
models = installer.all_models()
|
||||
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
|
||||
default_models = (
|
||||
[install_helper.default_model()] if program_opts.default_only else install_helper.recommended_models()
|
||||
)
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else list(),
|
||||
install_models=default_models if program_opts.yes_to_all else list(),
|
||||
)
|
||||
|
||||
|
||||
@@ -720,7 +804,7 @@ def maybe_create_models_yaml(root: Path):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path, install_helper: InstallHelper) -> (Namespace, Namespace):
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
@@ -729,13 +813,7 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
return (None, None)
|
||||
@@ -894,7 +972,8 @@ def main():
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
errors = set()
|
||||
|
||||
@@ -907,14 +986,22 @@ def main():
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
# this will initialize the models.yaml file if not present
|
||||
try:
|
||||
install_helper = InstallHelper(config)
|
||||
except ConfigFileVersionMismatchException:
|
||||
config.model_config_db = migrate_models_store(config)
|
||||
install_helper = InstallHelper(config)
|
||||
|
||||
models_to_download = default_user_selections(opt, install_helper)
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
@@ -929,10 +1016,12 @@ def main():
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
|
||||
elif models_to_download:
|
||||
process_and_execute(opt, models_to_download)
|
||||
install_helper.add_or_delete(models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
|
||||
@@ -3,13 +3,15 @@ Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
"""
|
||||
|
||||
#### NOTE: THIS SCRIPT NO LONGER WORKS WITH REFACTORED MODEL MANAGER, AND WILL NOT BE UPDATED.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import diffusers
|
||||
import transformers
|
||||
@@ -21,8 +23,9 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel,
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
@@ -43,19 +46,14 @@ class MigrateTo3(object):
|
||||
self,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
installer: ModelInstallService,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.installer = installer
|
||||
self.src_paths = src_paths
|
||||
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, "w") as file:
|
||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
def create_directory_structure(self):
|
||||
"""
|
||||
Create the basic directory structure for the models folder.
|
||||
@@ -107,44 +105,10 @@ class MigrateTo3(object):
|
||||
Recursively walk through src directory, probe anything
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
|
||||
This is now trivially easy using the installer service.
|
||||
"""
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir, followlinks=True):
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root, d)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
directories_scanned.add(model)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
||||
continue
|
||||
model = Path(root, f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
self.installer.scan_directory(src_dir)
|
||||
|
||||
def migrate_support_models(self):
|
||||
"""
|
||||
@@ -260,23 +224,21 @@ class MigrateTo3(object):
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
||||
info = ModelProbe().heuristic_probe(vae)
|
||||
_, model_name = repo_id.split("/")
|
||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||
vae.save_pretrained(dest, safe_serialization=True)
|
||||
return dest
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Optional[Path]:
|
||||
self.installer.install(repo_id) # bug! We don't support subfolder yet.
|
||||
ids = self.installer.wait_for_installs()
|
||||
if key := ids.get(repo_id):
|
||||
return self.installer.store.get_model(key).path
|
||||
else:
|
||||
return None
|
||||
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
||||
"""
|
||||
Convert 2.3 VAE stanza to a straight path.
|
||||
"""
|
||||
vae_path = None
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Optional[Path]:
|
||||
"""Convert 2.3 VAE stanza to a straight path."""
|
||||
vae_path: Optional[Path] = None
|
||||
|
||||
# First get a path
|
||||
if isinstance(vae, str):
|
||||
vae_path = vae
|
||||
vae_path = Path(vae)
|
||||
|
||||
elif isinstance(vae, DictConfig):
|
||||
if p := vae.get("path"):
|
||||
@@ -284,28 +246,21 @@ class MigrateTo3(object):
|
||||
elif repo_id := vae.get("repo_id"):
|
||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
||||
return vae_path
|
||||
return Path(vae_path)
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
||||
|
||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||
if vae_path is None:
|
||||
return None
|
||||
|
||||
# if the VAE is in the old models directory, then we must move it into the new
|
||||
# one. VAEs outside of this directory can stay where they are.
|
||||
vae_path = Path(vae_path)
|
||||
if vae_path.is_relative_to(self.src_paths.models):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path, dest)
|
||||
else:
|
||||
self.copy_file(vae_path, dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
rel_path = vae_path.relative_to(self.dest_models)
|
||||
return Path("models", rel_path)
|
||||
key = self.installer.install_path(vae_path) # this will move the model
|
||||
return self.installer.store.get_model(key).path
|
||||
elif vae_path.is_relative_to(self.dest_models):
|
||||
key = self.installer.register_path(vae_path) # this will keep the model in place
|
||||
return self.installer.store.get_model(key).path
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
@@ -501,44 +456,27 @@ def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
def do_migrate(config: InvokeAIAppConfig, src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
||||
dest_models = dest_directory / "models.3"
|
||||
mm_store = ModelRecordServiceBase.open(config)
|
||||
mm_install = ModelInstallService(config=config, store=mm_store)
|
||||
|
||||
version_3 = (dest_directory / "models" / "core").exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
# file already exists, then we write into a copy of it to
|
||||
# avoid deleting its previous customizations. Otherwise we
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||
except Exception:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / "models").replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
if not version_3:
|
||||
src_directory = (dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, installer=mm_install, src_paths=paths)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
||||
|
||||
config_file.replace(config_file.with_suffix(""))
|
||||
dest_models.replace(dest_models.with_suffix(""))
|
||||
|
||||
|
||||
@@ -588,7 +526,7 @@ script, which will perform a full upgrade in place.""",
|
||||
|
||||
initialize_rootdir(dest_root, True)
|
||||
|
||||
do_migrate(src_root, dest_root)
|
||||
do_migrate(config, src_root, dest_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user