mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 17:48:13 -05:00
Compare commits
452 Commits
feat/ui/me
...
feat/laten
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b897ca18ce | ||
|
|
cc21fb216c | ||
|
|
6fe62a2705 | ||
|
|
da87378713 | ||
|
|
b6f5267385 | ||
|
|
f9e78d3c64 | ||
|
|
b7b5bd1b46 | ||
|
|
9a3727d3ad | ||
|
|
d68c14516c | ||
|
|
9f4d39aa42 | ||
|
|
84b801d88f | ||
|
|
2fc70c509b | ||
|
|
34fb1c4b19 | ||
|
|
80bdd550cf | ||
|
|
2359b92b46 | ||
|
|
a404fb2d32 | ||
|
|
513eb11616 | ||
|
|
d2c9140e69 | ||
|
|
d95fe5925a | ||
|
|
835922ea8f | ||
|
|
e1e5266fc3 | ||
|
|
5e4457445f | ||
|
|
0221ca8f49 | ||
|
|
cf36e4029e | ||
|
|
c8a98a9a22 | ||
|
|
38ecca9362 | ||
|
|
c4681774a5 | ||
|
|
050add58d2 | ||
|
|
3d60c958c7 | ||
|
|
f5df150097 | ||
|
|
dac82adb5b | ||
|
|
b72c9787a9 | ||
|
|
2623941d91 | ||
|
|
d3a7fea939 | ||
|
|
5a7b687c84 | ||
|
|
0020457fc7 | ||
|
|
658b556544 | ||
|
|
37da0fc075 | ||
|
|
6d3e8507cc | ||
|
|
0e9470503f | ||
|
|
d2ebc6741b | ||
|
|
026d3260b4 | ||
|
|
78533714e3 | ||
|
|
691e1bf829 | ||
|
|
47a088d685 | ||
|
|
63db3fc22f | ||
|
|
ad0bb3f61a | ||
|
|
8f8cd90787 | ||
|
|
d796ea7bec | ||
|
|
e5b7dd63e9 | ||
|
|
af060188bd | ||
|
|
4270e7ae25 | ||
|
|
60a565d7de | ||
|
|
78cf70eaad | ||
|
|
eebaa50710 | ||
|
|
7d582553f2 | ||
|
|
4d6eea7e81 | ||
|
|
f44593331d | ||
|
|
3d9ecbf3c7 | ||
|
|
032aa1d59c | ||
|
|
35e0863bdb | ||
|
|
14070d674e | ||
|
|
108ce06c62 | ||
|
|
da364f3444 | ||
|
|
df5ba75c14 | ||
|
|
e4fb9cb33f | ||
|
|
65b527eb20 | ||
|
|
7dc9d18052 | ||
|
|
5013a4b9f3 | ||
|
|
f929359322 | ||
|
|
6522c71971 | ||
|
|
9c1e65f3a3 | ||
|
|
ebec200ba6 | ||
|
|
e559730b6e | ||
|
|
0acb8ed85d | ||
|
|
8c1c9cd702 | ||
|
|
0ece4686aa | ||
|
|
af95cef7f9 | ||
|
|
1eca7a918a | ||
|
|
9e6b958023 | ||
|
|
f7b99d93ae | ||
|
|
85d03dcd90 | ||
|
|
032555bcfe | ||
|
|
4caa1f19b2 | ||
|
|
95d4bd3012 | ||
|
|
037078c8ad | ||
|
|
6de2f66b50 | ||
|
|
cd7b248eda | ||
|
|
6d8c077f4e | ||
|
|
97127e560e | ||
|
|
27dc07d95a | ||
|
|
f7dc171c4f | ||
|
|
4b957edfec | ||
|
|
46ca7718d9 | ||
|
|
b928d7a6e6 | ||
|
|
8a836247c8 | ||
|
|
95c3644564 | ||
|
|
799cd07174 | ||
|
|
9af385468d | ||
|
|
3487388788 | ||
|
|
9a383e456d | ||
|
|
805f9f8f4a | ||
|
|
52aa0c9bbd | ||
|
|
7f5f4689cc | ||
|
|
a3f81f4b98 | ||
|
|
15c59e606f | ||
|
|
40d4cabecd | ||
|
|
3493c8119b | ||
|
|
c1e7460d39 | ||
|
|
3ffff023b2 | ||
|
|
f9384be59b | ||
|
|
6cf308004a | ||
|
|
d1029138d2 | ||
|
|
06b5800d28 | ||
|
|
483f2ccb56 | ||
|
|
93ced0bec6 | ||
|
|
4333852c37 | ||
|
|
3baa230077 | ||
|
|
9e594f9018 | ||
|
|
b0c41b4828 | ||
|
|
e0d6946b6b | ||
|
|
bf7ea8309f | ||
|
|
54b65f725f | ||
|
|
8ef49c2640 | ||
|
|
f488b1a7f2 | ||
|
|
d2edb7c402 | ||
|
|
f0a3f07b45 | ||
|
|
b42b630583 | ||
|
|
31a78d571b | ||
|
|
fdc2232ea0 | ||
|
|
e94d0b2d40 | ||
|
|
75ccbaee9c | ||
|
|
2848c8397c | ||
|
|
fe8b5193de | ||
|
|
3d1470399c | ||
|
|
fcf9c63049 | ||
|
|
7bfb5640ad | ||
|
|
15e57e3a3d | ||
|
|
279468c0e8 | ||
|
|
c565812723 | ||
|
|
ec6c8e2a38 | ||
|
|
77f2690711 | ||
|
|
c4b3a24ed7 | ||
|
|
33c69359c2 | ||
|
|
864f4bb4af | ||
|
|
5365f42a04 | ||
|
|
3dc60254b9 | ||
|
|
027a8562d7 | ||
|
|
34f3a0f0e3 | ||
|
|
d0bac1675e | ||
|
|
4e56c962f4 | ||
|
|
4ef0e43759 | ||
|
|
6945d10297 | ||
|
|
4d6cef7ac8 | ||
|
|
a7786d5ff2 | ||
|
|
6c1de975d9 | ||
|
|
a1079e455a | ||
|
|
5457c7f069 | ||
|
|
b8c1a3f96c | ||
|
|
cee8e85f76 | ||
|
|
09f166577e | ||
|
|
bcc21531fb | ||
|
|
da4eacdffe | ||
|
|
6102e560ba | ||
|
|
ff3aa57117 | ||
|
|
49db6f4fac | ||
|
|
20f6a597ab | ||
|
|
04c453721c | ||
|
|
350ffecc1f | ||
|
|
b0557aa16b | ||
|
|
1c9429a6ea | ||
|
|
206e6b1730 | ||
|
|
357cee2849 | ||
|
|
0b49997bb6 | ||
|
|
5e09dd380d | ||
|
|
c7303adb0d | ||
|
|
ed1f096a6f | ||
|
|
6ab5d28cf3 | ||
|
|
a75148cb16 | ||
|
|
f7bbc4004a | ||
|
|
cee21ca082 | ||
|
|
08ec12b391 | ||
|
|
ff5e2a9a8c | ||
|
|
e0b9b5cc6c | ||
|
|
aca4770481 | ||
|
|
5d5157fc65 | ||
|
|
fb6ef61a4d | ||
|
|
ee24ad7b13 | ||
|
|
f8e90ba3f0 | ||
|
|
ad0b70ca23 | ||
|
|
7dfa135b2c | ||
|
|
beeaa05658 | ||
|
|
6b6d654f60 | ||
|
|
853c83d0c2 | ||
|
|
1809990ed4 | ||
|
|
79d49853d2 | ||
|
|
1f608d3743 | ||
|
|
df024dd982 | ||
|
|
45da85765c | ||
|
|
bd0ad59c27 | ||
|
|
cce40acba5 | ||
|
|
bc9491ab69 | ||
|
|
b909bac0dc | ||
|
|
8618e41b32 | ||
|
|
4687f94141 | ||
|
|
440912dcff | ||
|
|
8b87a26e7e | ||
|
|
44ae93df3e | ||
|
|
8f80ba9520 | ||
|
|
2b213da967 | ||
|
|
e91e1eb9aa | ||
|
|
b24129fb3e | ||
|
|
350b1421bb | ||
|
|
f01c79a94f | ||
|
|
463f6352ce | ||
|
|
a80fe05e23 | ||
|
|
58d7833c5c | ||
|
|
5012f61599 | ||
|
|
85c33823c3 | ||
|
|
c83a112669 | ||
|
|
e04ada1319 | ||
|
|
d866dcb3d2 | ||
|
|
81ec476f3a | ||
|
|
1e6adf0a06 | ||
|
|
7d221e2518 | ||
|
|
56d3cbead0 | ||
|
|
5e8c97f1ba | ||
|
|
4687ad4ed6 | ||
|
|
994b247f8e | ||
|
|
0419f50ab0 | ||
|
|
f9f40adcdc | ||
|
|
3264d30b44 | ||
|
|
4d885653e9 | ||
|
|
475b6bef53 | ||
|
|
d39de0ad38 | ||
|
|
d14a7d756e | ||
|
|
b050c1bb8f | ||
|
|
276dfc591b | ||
|
|
b49d76ebee | ||
|
|
a6be44789b | ||
|
|
a4313c26cb | ||
|
|
d4b250d509 | ||
|
|
29743a9e02 | ||
|
|
fecb77e344 | ||
|
|
779671753d | ||
|
|
d5e152b35e | ||
|
|
270657a62c | ||
|
|
3601b9c860 | ||
|
|
c8fe12cd91 | ||
|
|
deae5fbaec | ||
|
|
5b558af2b3 | ||
|
|
4150d5306f | ||
|
|
8c2e4700f9 | ||
|
|
adaecada20 | ||
|
|
258895bcc9 | ||
|
|
2eb7c25bae | ||
|
|
2e4e9434c1 | ||
|
|
0cad204e74 | ||
|
|
0bc2edc044 | ||
|
|
16488e7db8 | ||
|
|
974841926d | ||
|
|
8db20e0d95 | ||
|
|
d00d29d6b5 | ||
|
|
dc976cd665 | ||
|
|
6d6b986a66 | ||
|
|
bffdede0fa | ||
|
|
a4c258e9ec | ||
|
|
8d837558ac | ||
|
|
e673ed08ec | ||
|
|
f0e07bff5a | ||
|
|
3ec06a1fc3 | ||
|
|
6b79e2b407 | ||
|
|
0eed9dbc44 | ||
|
|
53c7832fd1 | ||
|
|
ca1cc0e2c2 | ||
|
|
5d8728c7ef | ||
|
|
a8cec4c7e6 | ||
|
|
2b5ccdc55f | ||
|
|
d92d5b5258 | ||
|
|
a591184d2a | ||
|
|
ee881e4c78 | ||
|
|
61fbb24e36 | ||
|
|
d582949488 | ||
|
|
de574eb4d9 | ||
|
|
bfd90968f1 | ||
|
|
4a924c9b54 | ||
|
|
0453d60c64 | ||
|
|
c4f4f8b1b8 | ||
|
|
3e80eaa342 | ||
|
|
00a0cb3403 | ||
|
|
ea93cad5ff | ||
|
|
4453a0d20d | ||
|
|
1e837e3c9d | ||
|
|
0f95f7cea3 | ||
|
|
0b0068ab86 | ||
|
|
31c7fa833e | ||
|
|
db16ca0079 | ||
|
|
a824f47bc6 | ||
|
|
99392debe8 | ||
|
|
0cc739afc8 | ||
|
|
0ab62b0343 | ||
|
|
75d25dd5cc | ||
|
|
2e54da13d8 | ||
|
|
f34f416bf5 | ||
|
|
021c63891d | ||
|
|
a968862e6b | ||
|
|
a08189d457 | ||
|
|
0a936696c3 | ||
|
|
55e33eaf4c | ||
|
|
3da5fb223f | ||
|
|
a3c5a664e5 | ||
|
|
b638fb2f30 | ||
|
|
c1b10b2222 | ||
|
|
bee29714d9 | ||
|
|
d40d5276dd | ||
|
|
568f0aad71 | ||
|
|
38474fa9d4 | ||
|
|
f7f974a28b | ||
|
|
3c150b384c | ||
|
|
65816049ba | ||
|
|
c1c881ded5 | ||
|
|
82c4dd8b86 | ||
|
|
711d09a107 | ||
|
|
74013b6611 | ||
|
|
790f399986 | ||
|
|
73cdd36594 | ||
|
|
50ac3eb28d | ||
|
|
d753cff91a | ||
|
|
89f1909e4b | ||
|
|
37916a22ad | ||
|
|
76e5d0595d | ||
|
|
f03cb8f134 | ||
|
|
c2a0e8afc3 | ||
|
|
31a904b903 | ||
|
|
c174cab3ee | ||
|
|
fe12938c23 | ||
|
|
4fa5c963a1 | ||
|
|
48ce256ba2 | ||
|
|
8cb2fa8600 | ||
|
|
8f460b92f1 | ||
|
|
d99a08a441 | ||
|
|
7555b1f876 | ||
|
|
a537231f19 | ||
|
|
8044d1b840 | ||
|
|
2b58ce4ae4 | ||
|
|
ef605cd76c | ||
|
|
a84b5b168f | ||
|
|
16f6ee04d0 | ||
|
|
44be057aa3 | ||
|
|
422f6967b2 | ||
|
|
4528cc8ba6 | ||
|
|
87e91ebc1d | ||
|
|
fd00d111ea | ||
|
|
b8dc9000bd | ||
|
|
58c1066765 | ||
|
|
37096a697b | ||
|
|
17d0920186 | ||
|
|
1e05538364 | ||
|
|
cf28617cd6 | ||
|
|
d0d8640711 | ||
|
|
e6158d1874 | ||
|
|
2e9d1ea8a3 | ||
|
|
59b0153236 | ||
|
|
9f8ff912c4 | ||
|
|
f0e4a2124a | ||
|
|
11ab5c7d56 | ||
|
|
3f334d9e5e | ||
|
|
ff891b1ff2 | ||
|
|
2914ee10b0 | ||
|
|
e29c2fb782 | ||
|
|
b763f1809e | ||
|
|
d26b44104a | ||
|
|
b73fd2a6d2 | ||
|
|
f258aba6d1 | ||
|
|
2e70848aa0 | ||
|
|
e973aeef0d | ||
|
|
50e1ac731d | ||
|
|
43addc1548 | ||
|
|
4901911c1a | ||
|
|
44a653925a | ||
|
|
94a07a8da7 | ||
|
|
ad41afe65e | ||
|
|
77fa7519c4 | ||
|
|
6e29148d4d | ||
|
|
3044f3bfe5 | ||
|
|
67a8627cf6 | ||
|
|
3fb433cb91 | ||
|
|
5f498e10bd | ||
|
|
fdad62e88b | ||
|
|
955c81acef | ||
|
|
e1058f3416 | ||
|
|
edf16a253d | ||
|
|
46f5ef4100 | ||
|
|
b843255236 | ||
|
|
3a968e5072 | ||
|
|
b164330e3c | ||
|
|
69433c9f68 | ||
|
|
bd8ffd36bf | ||
|
|
fd80e84ea6 | ||
|
|
4824237a98 | ||
|
|
2c9a05eb59 | ||
|
|
ecb5bdaf7e | ||
|
|
2feeb1f44c | ||
|
|
554f353773 | ||
|
|
f6cdff2c5b | ||
|
|
aee27e94c9 | ||
|
|
695893e1ac | ||
|
|
b800a8eb2e | ||
|
|
9749ef34b5 | ||
|
|
9a43362127 | ||
|
|
866024ea6c | ||
|
|
601cc1f92c | ||
|
|
d6a9a4464d | ||
|
|
dac271725a | ||
|
|
e1fbecfcf7 | ||
|
|
63d10027a4 | ||
|
|
ef0773b8a3 | ||
|
|
3daaddf15b | ||
|
|
570c3fe690 | ||
|
|
cbd1a7263a | ||
|
|
7fc5fbd4ce | ||
|
|
6f6de402ad | ||
|
|
2ec4f5af10 | ||
|
|
281662a6e1 | ||
|
|
2edd032ec7 | ||
|
|
50eb02f68b | ||
|
|
d73f3adc43 | ||
|
|
116107f464 | ||
|
|
da44bb1707 | ||
|
|
f43aed677e | ||
|
|
0d051aaae2 | ||
|
|
e4e48ff995 | ||
|
|
442a6bffa4 | ||
|
|
aab262d991 | ||
|
|
47b9910b48 | ||
|
|
0b0e6fe448 | ||
|
|
23d65e7162 | ||
|
|
024fd54d0b | ||
|
|
c44c19e911 | ||
|
|
c132dbdefa | ||
|
|
f3081e7013 | ||
|
|
f904f14f9e | ||
|
|
8917a6d99b | ||
|
|
5a4765046e | ||
|
|
cee159dfa3 | ||
|
|
cd1b350dae | ||
|
|
8334757af9 | ||
|
|
bc2b9500e3 | ||
|
|
32857d81c5 | ||
|
|
28f75d80d5 | ||
|
|
b917ffa4d7 | ||
|
|
f682fb8040 |
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@@ -1,16 +1,16 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @mauwii @blessedcoolant
|
||||
/docs/ @lstein @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
||||
/docker/ @mauwii @lstein @blessedcoolant
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
/docker/ @lstein @blessedcoolant
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @lstein @ebr
|
||||
/invokeai/assets @lstein @ebr
|
||||
@@ -22,11 +22,11 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr @mauwii
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||
|
||||
15
.github/workflows/mkdocs-material.yml
vendored
15
.github/workflows/mkdocs-material.yml
vendored
@@ -2,8 +2,7 @@ name: mkdocs-material
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
- 'refs/heads/v2.3'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -12,6 +11,10 @@ jobs:
|
||||
mkdocs-material:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
@@ -22,11 +25,15 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install -r docs/requirements-mkdocs.txt
|
||||
pip install ".[docs]"
|
||||
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
@@ -36,7 +43,7 @@ jobs:
|
||||
--verbose
|
||||
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
if: ${{ github.ref == 'refs/heads/v2.3' }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
|
||||
configs/models.user.yaml
|
||||
config/models.user.yml
|
||||
invokeai.init
|
||||
.version
|
||||
.last_model
|
||||
|
||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||
anaconda.sh
|
||||
|
||||
@@ -33,6 +33,8 @@
|
||||
|
||||
</div>
|
||||
|
||||
_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_
|
||||
|
||||
InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.
|
||||
|
||||
**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
|
||||
@@ -84,7 +86,7 @@ installing lots of models.
|
||||
|
||||
6. Wait while the installer does its thing. After installing the software,
|
||||
the installer will launch a script that lets you configure InvokeAI and
|
||||
select a set of starting image generaiton models.
|
||||
select a set of starting image generation models.
|
||||
|
||||
7. Find the folder that InvokeAI was installed into (it is not the
|
||||
same as the unpacked zip file directory!) The default location of this
|
||||
@@ -148,6 +150,11 @@ not supported.
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
```
|
||||
|
||||
_For non-GPU systems:_
|
||||
```terminal
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
_For Macintoshes, either Intel or M1/M2:_
|
||||
|
||||
```sh
|
||||
|
||||
@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
|
||||
At installation time, InvokeAI will ask whether the checker should be
|
||||
activated by default (neither argument given on the command line). The
|
||||
response is stored in the InvokeAI initialization file (usually
|
||||
`.invokeai` in your home directory). You can change the default at any
|
||||
`invokeai.init` in your home directory). You can change the default at any
|
||||
time by opening this file in a text editor and commenting or
|
||||
uncommenting the line `--nsfw_checker`.
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ experimental versions later.
|
||||
sudo apt update
|
||||
sudo apt install -y software-properties-common
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt install python3.10 python3-pip python3.10-venv
|
||||
sudo apt install -y python3.10 python3-pip python3.10-venv
|
||||
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
|
||||
```
|
||||
|
||||
|
||||
@@ -247,8 +247,8 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torch~=2.0.0",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from invokeai.app.models.image import ImageField
|
||||
from invokeai.app.services.outputs_sqlite import OutputsSqliteItemStorage
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from typing import types
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.metadata import PngMetadataService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@@ -40,15 +45,16 @@ class ApiDependencies:
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int):
|
||||
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
# TO DO: Use the config to select the logger rather than use the default
|
||||
# invokeai logging module
|
||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
@@ -58,24 +64,34 @@ class ApiDependencies:
|
||||
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images')
|
||||
metadata = PngMetadataService()
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config),
|
||||
model_manager=get_model_manager(config,logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
outputs=OutputsSqliteItemStorage(filename=db_location),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -45,7 +45,7 @@ class FastAPIEventService(EventServiceBase):
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.001)
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||
|
||||
|
||||
class ImageResponseMetadata(BaseModel):
|
||||
"""An image's metadata. Used only in HTTP responses."""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
@@ -11,4 +23,18 @@ class ImageResponse(BaseModel):
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageMetadata = Field(description="The image's metadata")
|
||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class SavedImage(BaseModel):
|
||||
image_name: str = Field(description="The name of the saved image")
|
||||
thumbnail_name: str = Field(description="The name of the saved thumbnail")
|
||||
created: int = Field(description="The created timestamp of the saved image")
|
||||
|
||||
@@ -3,14 +3,17 @@ import io
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from fastapi import Path, Query, Request, UploadFile
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.api.models.images import (
|
||||
ImageResponse,
|
||||
ImageResponseMetadata,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
@@ -18,85 +21,128 @@ from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
):
|
||||
"""Gets a result"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
return FileResponse(filename)
|
||||
) -> FileResponse:
|
||||
"""Gets an image"""
|
||||
|
||||
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_type: ImageType = Path(description="The type of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image and its thumbnail"""
|
||||
|
||||
ApiDependencies.invoker.services.images.delete(
|
||||
image_type=image_type, image_name=image_name
|
||||
)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
|
||||
)
|
||||
async def get_thumbnail(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
):
|
||||
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
|
||||
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
|
||||
) -> FileResponse | Response:
|
||||
"""Gets a thumbnail"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
|
||||
)
|
||||
|
||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||
return FileResponse(path)
|
||||
else:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/uploads/",
|
||||
operation_id="upload_image",
|
||||
responses={
|
||||
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
|
||||
404: {"description": "Session not found"},
|
||||
201: {
|
||||
"description": "The image was uploaded successfully",
|
||||
"model": ImageResponse,
|
||||
},
|
||||
415: {"description": "Image upload failed"},
|
||||
},
|
||||
status_code=201
|
||||
status_code=201,
|
||||
)
|
||||
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
|
||||
async def upload_image(
|
||||
file: UploadFile, image_type: ImageType, request: Request, response: Response
|
||||
) -> ImageResponse:
|
||||
if not file.content_type.startswith("image"):
|
||||
return Response(status_code=415)
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await file.read()
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(contents))
|
||||
except:
|
||||
# Error opening the image
|
||||
return Response(status_code=415)
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
|
||||
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
|
||||
|
||||
saved_image = ApiDependencies.invoker.services.images.save(
|
||||
image_type, filename, img
|
||||
)
|
||||
|
||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||
|
||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
image_type, saved_image.image_name
|
||||
)
|
||||
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
image_type, saved_image.image_name, True
|
||||
)
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
created=int(os.path.getctime(image_path)),
|
||||
image_type=image_type,
|
||||
image_name=saved_image.image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
metadata=ImageResponseMetadata(
|
||||
created=saved_image.created,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
response.headers["Location"] = request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||
)
|
||||
response.headers["Location"] = image_url
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
||||
image_type: ImageType = Query(
|
||||
default=ImageType.RESULT, description="The type of images to get"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(
|
||||
image_type, page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
||||
return result
|
||||
|
||||
@@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@@ -112,19 +108,20 @@ async def update_model(
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
print(f">> Checking for model {model_name}...")
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
print(f">> Model not found")
|
||||
logger.error(f"Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
@@ -248,4 +245,4 @@ async def delete_model(model_name: str) -> None:
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
# except Exception as e:
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.responses import Response
|
||||
from fastapi import Body, HTTPException, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.fields import Field
|
||||
|
||||
@@ -76,7 +75,7 @@ async def get_session(
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
return session
|
||||
|
||||
@@ -99,7 +98,7 @@ async def add_node(
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_node(node)
|
||||
@@ -108,9 +107,9 @@ async def add_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session.id
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
@@ -132,7 +131,7 @@ async def update_node(
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.update_node(node_path, node)
|
||||
@@ -141,9 +140,9 @@ async def update_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
@@ -162,7 +161,7 @@ async def delete_node(
|
||||
"""Deletes a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.delete_node(node_path)
|
||||
@@ -171,9 +170,9 @@ async def delete_node(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.post(
|
||||
@@ -192,7 +191,7 @@ async def add_edge(
|
||||
"""Adds an edge to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
session.add_edge(edge)
|
||||
@@ -201,9 +200,9 @@ async def add_edge(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# TODO: the edge being in the path here is really ugly, find a better solution
|
||||
@@ -226,7 +225,7 @@ async def delete_edge(
|
||||
"""Deletes an edge from the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
try:
|
||||
edge = Edge(
|
||||
@@ -239,9 +238,9 @@ async def delete_edge(
|
||||
) # TODO: can this be done automatically, or add node through an API?
|
||||
return session
|
||||
except NodeAlreadyExecutedError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
except IndexError:
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
@session_router.put(
|
||||
@@ -259,14 +258,14 @@ async def invoke_session(
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
) -> None:
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
return Response(status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
if session.is_complete():
|
||||
return Response(status_code=400)
|
||||
raise HTTPException(status_code=400)
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
@@ -281,7 +280,7 @@ async def invoke_session(
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
) -> None:
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
ApiDependencies.invoker.cancel(session_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import uvicorn
|
||||
import invokeai.backend.util.logging as logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
@@ -16,7 +17,6 @@ from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
# Create the app
|
||||
@@ -56,7 +56,7 @@ async def startup_event():
|
||||
config.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=config, event_handler_id=event_handler_id
|
||||
config=config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
|
||||
|
||||
@@ -126,7 +126,6 @@ app.openapi = custom_openapi
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
return get_swagger_ui_html(
|
||||
@@ -144,6 +143,8 @@ def overridden_redoc():
|
||||
redoc_favicon_url="/static/favicon.ico",
|
||||
)
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
|
||||
@@ -2,16 +2,46 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..models.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
@@ -36,30 +66,26 @@ def add_parsers(
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default if field.default_factory is None else field.default_factory(),
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
|
||||
class CliContext:
|
||||
@@ -67,17 +93,38 @@ class CliContext:
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
@@ -183,7 +230,7 @@ class HistoryCommand(BaseCommand):
|
||||
for i in range(min(self.count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = context.get_session().graph.get_node(entry_id)
|
||||
print(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
|
||||
@@ -10,6 +10,7 @@ import shlex
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager, Globals
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
@@ -160,8 +161,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
print(
|
||||
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
@@ -13,17 +13,21 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@@ -58,7 +62,7 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -76,20 +80,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@@ -98,14 +154,14 @@ def generate_matching_edges(
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
)
|
||||
for field in matching_fields
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@@ -125,7 +181,7 @@ def invoke_all(context: CliContext):
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
print(
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
@@ -135,16 +191,18 @@ def invoke_all(context: CliContext):
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config)
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
completer = set_autocompleter(model_manager)
|
||||
set_autocompleter(model_manager)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
@@ -156,18 +214,26 @@ def invoke_cli():
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||
metadata=metadata,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
@@ -185,11 +251,12 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(history)
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@@ -205,8 +272,24 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@@ -217,6 +300,7 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@@ -229,7 +313,7 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
from_node, command.command, context
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
@@ -242,7 +326,7 @@ def invoke_cli():
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@@ -256,12 +340,14 @@ def invoke_cli():
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -270,22 +356,22 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.session.add_edge(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
@@ -95,7 +95,7 @@ class UIConfig(TypedDict, total=False):
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
|
||||
title: str
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
)
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
@@ -33,7 +34,9 @@ class RangeInvocation(BaseInvocation):
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
return IntCollectionOutput(
|
||||
collection=list(range(self.start, self.stop, self.step))
|
||||
)
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
@@ -43,8 +46,19 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntCollectionOutput(
|
||||
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
|
||||
)
|
||||
|
||||
246
invokeai/app/invocations/compel.py
Normal file
246
invokeai/app/invocations/compel.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
)
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
class Config:
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
|
||||
class CompelOutput(BaseInvocationOutput):
|
||||
"""Compel parser output"""
|
||||
|
||||
#fmt: off
|
||||
type: Literal["compel_output"] = "compel_output"
|
||||
|
||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||
#fmt: on
|
||||
|
||||
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
type: Literal["compel"] = "compel"
|
||||
|
||||
prompt: str = Field(default="", description="Prompt")
|
||||
model: str = Field(default="", description="Model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Prompt (Compel)",
|
||||
"tags": ["prompt", "compel"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||
|
||||
# TODO: load without model
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
pipeline = model["model"]
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
|
||||
# TODO: global? input?
|
||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
||||
#use_full_precision = False
|
||||
|
||||
# TODO: redo TI when separate model loding implemented
|
||||
#textual_inversion_manager = TextualInversionManager(
|
||||
# tokenizer=tokenizer,
|
||||
# text_encoder=text_encoder,
|
||||
# full_precision=use_full_precision,
|
||||
#)
|
||||
|
||||
def load_huggingface_concepts(concepts: list[str]):
|
||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
||||
self.prompt,
|
||||
lambda concepts: load_huggingface_concepts(concepts),
|
||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
||||
)
|
||||
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_str
|
||||
)
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
)
|
||||
|
||||
# TODO: support legacy blend?
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt_str)
|
||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||
|
||||
if getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
#if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.set(conditioning_name, (c, ec))
|
||||
|
||||
return CompelOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
||||
for c in blend.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
||||
)
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> [str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
for i, c in enumerate(blend.prompts):
|
||||
log_tokenization_for_prompt_object(
|
||||
c,
|
||||
tokenizer,
|
||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
||||
)
|
||||
elif type(p) is FlattenedPrompt:
|
||||
flattened_prompt: FlattenedPrompt = p
|
||||
if flattened_prompt.wants_cross_attention_control:
|
||||
original_fragments = []
|
||||
edited_fragments = []
|
||||
for f in flattened_prompt.children:
|
||||
if type(f) is CrossAttentionControlSubstitute:
|
||||
original_fragments += f.original
|
||||
edited_fragments += f.edited
|
||||
else:
|
||||
original_fragments.append(f)
|
||||
edited_fragments.append(f)
|
||||
|
||||
original_text = " ".join([x.text for x in original_fragments])
|
||||
log_tokenization_for_text(
|
||||
original_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap originals)",
|
||||
)
|
||||
edited_text = " ".join([x.text for x in edited_fragments])
|
||||
log_tokenization_for_text(
|
||||
edited_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace("</w>", " ")
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
else:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
print(f"{discarded}\x1b[0m")
|
||||
@@ -56,9 +56,14 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_inpainted, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_inpainted,
|
||||
)
|
||||
)
|
||||
@@ -1,24 +1,26 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..models.exceptions import CanceledException
|
||||
from ..util.step_callback import diffusers_step_callback_adapter
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
@@ -45,42 +47,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model_name = model["model_name"]
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
@@ -97,21 +99,22 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
invocation = graph_execution_state.execution_graph.get_node(self.id)
|
||||
|
||||
metadata = {
|
||||
"session": context.graph_execution_state_id,
|
||||
"source_id": source_id,
|
||||
"invocation": invocation.dict()
|
||||
}
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, generate_output.image, metadata)
|
||||
context.services.images.save(
|
||||
image_type, image_name, generate_output.image, metadata
|
||||
)
|
||||
|
||||
context.services.outputs.set(image_name, context.graph_execution_state_id)
|
||||
|
||||
s = context.services.outputs.get(image_name)
|
||||
print(s)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=generate_output.image
|
||||
image=generate_output.image,
|
||||
)
|
||||
|
||||
|
||||
@@ -131,20 +134,17 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
@@ -154,21 +154,27 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = None
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model = model["model_name"]
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
@@ -183,11 +189,16 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, result_image, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image
|
||||
image=result_image,
|
||||
)
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
@@ -197,6 +208,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
||||
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
||||
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
||||
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
@@ -205,20 +227,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
@@ -236,17 +255,22 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
self.model = model["model_name"]
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
@@ -261,9 +285,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, result_image, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=result_image
|
||||
image=result_image,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
@@ -8,7 +8,6 @@ from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -39,22 +38,22 @@ class ImageOutput(BaseInvocationOutput):
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"image",
|
||||
"width",
|
||||
"height",
|
||||
]
|
||||
}
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
image_type: ImageType, image_name: str, image: Image.Image
|
||||
) -> ImageOutput:
|
||||
image_field = ImageField(image_name=image_name, image_type=image_type)
|
||||
|
||||
return ImageOutput(image=image_field, width=image.width, height=image.height)
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
image_field = ImageField(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=image_field,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
@@ -74,23 +73,24 @@ class MaskOutput(BaseInvocationOutput):
|
||||
}
|
||||
|
||||
|
||||
# # TODO: this isn't really necessary anymore
|
||||
# class LoadImageInvocation(BaseInvocation):
|
||||
# """Load an image from a filename and provide it as output."""
|
||||
# #fmt: off
|
||||
# type: Literal["load_image"] = "load_image"
|
||||
class LoadImageInvocation(BaseInvocation):
|
||||
"""Load an image and provide it as output."""
|
||||
|
||||
# # Inputs
|
||||
# image_type: ImageType = Field(description="The type of the image")
|
||||
# image_name: str = Field(description="The name of the image")
|
||||
# #fmt: on
|
||||
# fmt: off
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# return ImageOutput(
|
||||
# image_type=self.image_type,
|
||||
# image_name=self.image_name,
|
||||
# image=result_image
|
||||
# )
|
||||
# Inputs
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image_type, self.image_name)
|
||||
|
||||
return build_image_output(
|
||||
image_type=self.image_type,
|
||||
image_name=self.image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
@@ -145,9 +145,16 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_crop, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=image_crop
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image_crop,
|
||||
)
|
||||
|
||||
|
||||
@@ -196,9 +203,16 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, new_image, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=new_image
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=new_image,
|
||||
)
|
||||
|
||||
|
||||
@@ -226,7 +240,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image_mask, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, image_mask, metadata)
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
|
||||
@@ -258,7 +277,12 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, blur_image, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, blur_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=blur_image
|
||||
)
|
||||
@@ -290,7 +314,12 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, lerp_image, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=lerp_image
|
||||
)
|
||||
@@ -327,7 +356,12 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
||||
)
|
||||
|
||||
233
invokeai/app/invocations/infill.py
Normal file
233
invokeai/app/invocations/infill.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput, build_image_output
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageField, ImageType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
)
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
|
||||
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image)
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
@@ -1,24 +1,29 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional
|
||||
import random
|
||||
from typing import Literal, Optional, Union
|
||||
import einops
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput, build_image_output
|
||||
from .compel import ConditioningField
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
@@ -30,45 +35,61 @@ class LatentsField(BaseModel):
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["latents_name"]}
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latent_output"] = "latent_output"
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -105,9 +126,9 @@ class NoiseInvocation(BaseInvocation):
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
|
||||
# Schema customisation
|
||||
@@ -124,32 +145,26 @@ class NoiseInvocation(BaseInvocation):
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=name)
|
||||
)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from a prompt."""
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
type: Literal["t2l"] = "t2l"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
@@ -165,22 +180,15 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
@@ -190,7 +198,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
model=model,
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
@@ -206,8 +214,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
@@ -219,18 +229,22 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state)
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
@@ -247,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
@@ -257,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
@@ -268,31 +284,27 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state)
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
@@ -308,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@@ -350,9 +360,123 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image, self.dict())
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type, image_name=image_name, image=image
|
||||
)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
]
|
||||
|
||||
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
scale_factor=self.scale_factor,
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
@@ -73,3 +74,12 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
#fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
#fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))
|
||||
|
||||
18
invokeai/app/invocations/params.py
Normal file
18
invokeai/app/invocations/params.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
@@ -1,10 +1,9 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
|
||||
@@ -44,7 +43,12 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0], self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
|
||||
@@ -49,7 +47,12 @@ class UpscaleInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0], self.dict())
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
|
||||
@@ -3,12 +3,12 @@ from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
logger = model_manager.logger
|
||||
if model_name and not model_manager.valid_model(model_name):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
print(
|
||||
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
|
||||
)
|
||||
else:
|
||||
model = model_manager.get_model(model_name)
|
||||
|
||||
return model
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -9,6 +9,14 @@ class ImageType(str, Enum):
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
def is_image_type(obj):
|
||||
try:
|
||||
ImageType(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
@@ -18,9 +26,14 @@ class ImageField(BaseModel):
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"image_type",
|
||||
"image_name",
|
||||
]
|
||||
}
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
from typing import Any, Optional, Dict
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InvokeAIMetadata(BaseModel):
|
||||
"""An image's InvokeAI-specific metadata"""
|
||||
|
||||
session: Optional[str] = Field(description="The session that generated this image")
|
||||
source_id: Optional[str] = Field(
|
||||
description="The source id of the invocation that generated this image"
|
||||
)
|
||||
# TODO: figure out metadata
|
||||
invocation: Optional[Dict[str, Any]] = Field(
|
||||
default={}, description="The prepared invocation that generated this image"
|
||||
)
|
||||
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""An image's general metadata"""
|
||||
|
||||
created: int = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||
default={}, description="The image's InvokeAI-specific metadata"
|
||||
)
|
||||
64
invokeai/app/services/default_graphs.py
Normal file
64
invokeai/app/services/default_graphs.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.compel import CompelInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'seed': ParamIntInvocation(id='seed', a=-1),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': CompelInvocation(id='4'),
|
||||
'5': CompelInvocation(id='5'),
|
||||
'6': TextToLatentsInvocation(id='6'),
|
||||
'7': LatentsToImageInvocation(id='7'),
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
|
||||
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
|
||||
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height'),
|
||||
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='7', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
||||
@@ -1,10 +1,9 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Dict, TypedDict
|
||||
from typing import Any
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
ProgressImage = TypedDict(
|
||||
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
||||
)
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@@ -14,7 +13,8 @@ class EventServiceBase:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
@@ -25,8 +25,8 @@ class EventServiceBase:
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
invocation_dict: dict,
|
||||
source_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
progress_image: ProgressImage | None,
|
||||
step: int,
|
||||
total_steps: int,
|
||||
@@ -36,52 +36,60 @@ class EventServiceBase:
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
progress_image=progress_image,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(
|
||||
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
|
||||
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
invocation=invocation_dict,
|
||||
source_id=source_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -89,5 +97,7 @@ class EventServiceBase:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(graph_execution_state_id=graph_execution_state_id),
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
import uuid
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
@@ -17,7 +16,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
)
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
@@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@@ -283,7 +281,8 @@ class Graph(BaseModel):
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self._validate_edge(edge)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
@@ -354,7 +353,7 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
@@ -362,54 +361,53 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@@ -733,7 +731,7 @@ class Graph(BaseModel):
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
@@ -750,9 +748,7 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(
|
||||
description="The id of the execution state", default_factory=uuid.uuid4
|
||||
)
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
@@ -858,7 +854,8 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
@@ -946,11 +943,11 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph()
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
@@ -962,7 +959,7 @@ class GraphExecutionState(BaseModel):
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
return iterators
|
||||
|
||||
@@ -1098,7 +1095,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
try:
|
||||
self.graph._validate_edge(edge)
|
||||
except InvalidEdgeError:
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
@@ -1144,4 +1143,52 @@ class GraphExecutionState(BaseModel):
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import json
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
from typing import Dict, List
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from pydantic import BaseModel, Json
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from send2trash import send2trash
|
||||
from invokeai.app.api.models.images import (
|
||||
ImageResponse,
|
||||
ImageResponseMetadata,
|
||||
SavedImage,
|
||||
)
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.services.metadata import (
|
||||
InvokeAIMetadata,
|
||||
MetadataServiceBase,
|
||||
build_invokeai_metadata_pnginfo,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
@@ -27,12 +31,14 @@ class ImageStorageBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a paginated list of images."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@@ -40,35 +46,59 @@ class ImageStorageBase(ABC):
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_uri(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the external URI to an image or its thumbnail."""
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def validate_path(self, path: str) -> bool:
|
||||
"""Validates an image path."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> SavedImage:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
def create_name(self, context_id: str, node_id: str) -> str:
|
||||
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
|
||||
"""Creates a unique contextual image filename."""
|
||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
||||
|
||||
|
||||
class DiskImageStorage(ImageStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: str
|
||||
__pngWriter: PngWriter
|
||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||
__cache: Dict[str, Image]
|
||||
__max_cache_size: int
|
||||
__metadata_service: MetadataServiceBase
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
||||
self.__output_folder = output_folder
|
||||
self.__pngWriter = PngWriter(output_folder)
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
self.__metadata_service = metadata_service
|
||||
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -101,21 +131,22 @@ class DiskImageStorage(ImageStorageBase):
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
|
||||
|
||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
||||
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
image_url=self.get_uri(image_type, filename),
|
||||
thumbnail_url=self.get_uri(image_type, filename, True),
|
||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
||||
metadata=ImageResponseMetadata(
|
||||
created=int(os.path.getctime(path)),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
invokeai=invokeai_metadata
|
||||
invokeai=invokeai_metadata,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -146,45 +177,89 @@ class DiskImageStorage(ImageStorageBase):
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", image_name
|
||||
self.__output_folder, image_type, "thumbnails", basename
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
|
||||
print(metadata)
|
||||
image_subpath = os.path.join(image_type, image_name)
|
||||
self.__pngWriter.save_image_and_prompt_to_png(
|
||||
image, "", image_subpath, metadata
|
||||
) # TODO: just pass full path to png writer
|
||||
save_thumbnail(
|
||||
image=image,
|
||||
filename=image_name,
|
||||
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||
)
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
return abspath
|
||||
|
||||
def get_uri(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if is_thumbnail:
|
||||
thumbnail_basename = get_thumbnail_name(basename)
|
||||
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
|
||||
else:
|
||||
uri = f"api/v1/images/{image_type.value}/{basename}"
|
||||
|
||||
return uri
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
os.stat(path)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def save(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_name: str,
|
||||
image: Image,
|
||||
metadata: InvokeAIMetadata | None = None,
|
||||
) -> SavedImage:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
||||
# TODO: Reading the image and then saving it strips the metadata...
|
||||
if metadata:
|
||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||
else:
|
||||
image.save(image_path) # this saved image has an empty info
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
self.__set_cache(image_path, image)
|
||||
return image_path
|
||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||
|
||||
return SavedImage(
|
||||
image_name=image_name,
|
||||
thumbnail_name=thumbnail_name,
|
||||
created=int(os.path.getctime(image_path)),
|
||||
)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_type, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
def __get_cache(self, image_name: str) -> Image | None:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: str, image: Image):
|
||||
@@ -195,4 +270,5 @@ class DiskImageStorage(ImageStorageBase):
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
del self.__cache[cache_id]
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
|
||||
@@ -1,30 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
|
||||
# session_id: str
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
timestamp: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# session_id: str,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
invoke_all: bool = False,
|
||||
):
|
||||
# self.session_id = session_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
self.timestamp = time.time()
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import types
|
||||
|
||||
from invokeai.app.services.outputs_session_storage import OutputsSessionStorageABC
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
@@ -14,11 +19,14 @@ class InvocationServices:
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
outputs: OutputsSessionStorageABC
|
||||
metadata: MetadataServiceBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
@@ -26,18 +34,26 @@ class InvocationServices:
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
logger: types.ModuleType,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
outputs: OutputsSessionStorageABC,
|
||||
metadata: MetadataServiceBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.outputs = outputs
|
||||
self.metadata = metadata
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
|
||||
@@ -49,7 +49,7 @@ class Invoker:
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
@@ -71,18 +71,12 @@ class Invoker:
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the invoker. A new invoker will have to be created to execute further."""
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
|
||||
105
invokeai/app/services/metadata.py
Normal file
105
invokeai/app/services/metadata.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, TypedDict
|
||||
from PIL import Image, PngImagePlugin
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.models.image import ImageType, is_image_type
|
||||
|
||||
|
||||
class MetadataImageField(TypedDict):
|
||||
"""Pydantic-less ImageField, used for metadata parsing."""
|
||||
|
||||
image_type: ImageType
|
||||
image_name: str
|
||||
|
||||
|
||||
class MetadataLatentsField(TypedDict):
|
||||
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||||
|
||||
latents_name: str
|
||||
|
||||
|
||||
class MetadataColorField(TypedDict):
|
||||
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||
r: int
|
||||
g: int
|
||||
b: int
|
||||
a: int
|
||||
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
|
||||
]
|
||||
|
||||
|
||||
class InvokeAIMetadata(TypedDict, total=False):
|
||||
"""InvokeAI-specific metadata format."""
|
||||
|
||||
session_id: Optional[str]
|
||||
node: Optional[NodeMetadata]
|
||||
|
||||
|
||||
def build_invokeai_metadata_pnginfo(
|
||||
metadata: InvokeAIMetadata | None,
|
||||
) -> PngImagePlugin.PngInfo:
|
||||
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||
|
||||
return pnginfo
|
||||
|
||||
|
||||
class MetadataServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||||
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_metadata(
|
||||
self, session_id: str, node: BaseModel
|
||||
) -> InvokeAIMetadata | None:
|
||||
"""Builds an InvokeAIMetadata object"""
|
||||
pass
|
||||
|
||||
|
||||
class PngMetadataService(MetadataServiceBase):
|
||||
"""Handles loading and building metadata for images."""
|
||||
|
||||
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||||
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Loads a specific info entry from a PIL Image."""
|
||||
|
||||
try:
|
||||
info = image.info.get("invokeai")
|
||||
|
||||
if type(info) is not str:
|
||||
return None
|
||||
|
||||
loaded_metadata = json.loads(info)
|
||||
|
||||
if type(loaded_metadata) is not dict:
|
||||
return None
|
||||
|
||||
if len(loaded_metadata.items()) == 0:
|
||||
return None
|
||||
|
||||
return loaded_metadata
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_metadata(self, image: Image.Image) -> dict | None:
|
||||
"""Retrieves an image's metadata as a dict"""
|
||||
loaded_metadata = self._load_metadata(image)
|
||||
|
||||
return loaded_metadata
|
||||
|
||||
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||||
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||||
|
||||
return metadata
|
||||
@@ -5,6 +5,7 @@ from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from ...backend import ModelManager
|
||||
@@ -12,16 +13,16 @@ from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: Args) -> ModelManager:
|
||||
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
|
||||
)
|
||||
|
||||
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
@@ -62,11 +63,12 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = Path(embedding_path),
|
||||
logger = logger,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e)
|
||||
report_model_error(config, e, logger)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
logger.error(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# try to autoconvert new models
|
||||
@@ -76,18 +78,18 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
logger.info('Model manager initialized')
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print(
|
||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
print(
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
@@ -96,13 +98,12 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
if response.startswith(("n", "N")):
|
||||
return
|
||||
|
||||
print("invokeai-configure is launching....\n")
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
previous_config = sys.argv
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
|
||||
59
invokeai/app/services/outputs_session_storage.py
Normal file
59
invokeai/app/services/outputs_session_storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
class PaginatedStringResults(GenericModel):
|
||||
"""Paginated results"""
|
||||
#fmt: off
|
||||
items: list[str] = Field(description="Session IDs")
|
||||
page: int = Field(description="Current Page")
|
||||
pages: int = Field(description="Total number of pages")
|
||||
per_page: int = Field(description="Number of items per page")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
#fmt: on
|
||||
|
||||
class OutputsSessionStorageABC(ABC):
|
||||
_on_changed_callbacks: list[Callable[[str], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = list()
|
||||
self._on_deleted_callbacks = list()
|
||||
|
||||
"""Base item storage class"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, output_id: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, output_id: str, session_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedStringResults:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, foreign_key_value: str) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(foreign_key_value)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
||||
162
invokeai/app/services/outputs_sqlite.py
Normal file
162
invokeai/app/services/outputs_sqlite.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
from invokeai.app.services.outputs_session_storage import (
|
||||
OutputsSessionStorageABC,
|
||||
PaginatedStringResults,
|
||||
)
|
||||
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class OutputsSqliteItemStorage(OutputsSessionStorageABC):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS outputs (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
session_id TEXT NOT NULL
|
||||
);"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS outputs_id ON outputs(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def set(self, output_id: str, session_id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO outputs (id, session_id) VALUES (?, ?);""",
|
||||
(output_id, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(output_id)
|
||||
|
||||
def get(self, output_id: str) -> Union[str, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT graph_executions.item session
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
WHERE outputs.id = ?;
|
||||
""",
|
||||
(output_id,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return result[0]
|
||||
|
||||
def delete(self, output_id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM outputs WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(output_id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT graph_executions.item session
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: r[0], result))
|
||||
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT count(*)
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id;
|
||||
""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedStringResults(
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedStringResults:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT graph_executions.item session
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
WHERE outputs.id LIKE ? LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = list(map(lambda r: r[0], result))
|
||||
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
SELECT count(*)
|
||||
FROM graph_executions
|
||||
INNER JOIN outputs ON outputs.session_id = graph_executions.id
|
||||
WHERE outputs.id LIKE ?;
|
||||
""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedStringResults(
|
||||
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
@@ -1,17 +1,22 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
@@ -20,7 +25,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
)
|
||||
self.__invoker_thread.daemon = (
|
||||
True # TODO: probably better to just not use threads?
|
||||
True # TODO: make async and do not use threads
|
||||
)
|
||||
self.__invoker_thread.start()
|
||||
|
||||
@@ -29,9 +34,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
while not stop_event.is_set():
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
logger.debug("Exception while getting from queue: %s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
graph_execution_state = (
|
||||
@@ -43,14 +55,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item.invocation_id
|
||||
)
|
||||
|
||||
# get the source node to provide to cliepnts (the prepared node is not as useful)
|
||||
source_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id
|
||||
)
|
||||
|
||||
# Invoke
|
||||
@@ -79,8 +91,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.dict(),
|
||||
)
|
||||
|
||||
@@ -104,13 +116,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=error,
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
@@ -120,11 +132,22 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
except Exception as e:
|
||||
logger.error("Error while invoking: %s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=traceback.format_exc()
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
... # Log something?
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
finally:
|
||||
self.__threadLimit.release()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
@@ -10,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args):
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
@@ -20,20 +21,22 @@ class RestorationServices:
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
@@ -58,15 +61,15 @@ class RestorationServices:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@@ -80,7 +83,7 @@ class RestorationServices:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
print(">> Face Restoration is disabled.")
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@@ -93,10 +96,10 @@ class RestorationServices:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, TypeVar, Union, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
|
||||
from .item_storage import ItemStorageABC, PaginatedResults
|
||||
|
||||
from sqlalchemy import create_engine, String, TEXT, Engine, select
|
||||
from sqlalchemy.orm import DeclarativeBase, mapped_column, Session
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_engine: Engine
|
||||
# _table: ??? # TODO: figure out how to type this
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
@@ -27,79 +25,86 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._engine = create_engine(f"sqlite+pysqlite:///{self._filename}")
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
# dynamically create the ORM model class to avoid name collisions
|
||||
|
||||
# cannot access `self.__orig_class__` in `__init__` or `__new__` so
|
||||
# format the table name into the class name
|
||||
pascal_table_name = self._table_name.replace("_", " ").title()
|
||||
pascal_table_name = pascal_table_name.replace(" ", "")
|
||||
|
||||
table_dict = dict(
|
||||
__tablename__=self._table_name,
|
||||
id=mapped_column(String, primary_key=True),
|
||||
item=mapped_column(TEXT, nullable=False),
|
||||
)
|
||||
|
||||
self._table = type(pascal_table_name, (Base,), table_dict)
|
||||
|
||||
Base.metadata.create_all(self._engine)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
session = Session(self._engine)
|
||||
|
||||
item_id = str(getattr(item, self._id_field))
|
||||
new_item = self._table(id=item_id, item=item.json())
|
||||
|
||||
session.merge(new_item)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Union[T, None]:
|
||||
session = Session(self._engine)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
item = session.get(self._table, str(id))
|
||||
|
||||
session.close()
|
||||
|
||||
if not item:
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return self._parse_item(item.item)
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def delete(self, id: str):
|
||||
session = Session(self._engine)
|
||||
|
||||
item = session.get(self._table, id)
|
||||
session.delete(item)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
session = Session(self._engine)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
stmt = select(self._table.item).limit(per_page).offset(page * per_page)
|
||||
result = session.execute(stmt)
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
count = session.query(self._table.item).count()
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
@@ -110,19 +115,23 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
session = Session(self._engine)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
stmt = (
|
||||
session.query(self._table)
|
||||
.where(self._table.item.like(f"%{query}%"))
|
||||
.limit(per_page)
|
||||
.offset(page * per_page)
|
||||
)
|
||||
items = list(map(lambda r: self._parse_item(r[0]), result))
|
||||
|
||||
result = session.execute(stmt)
|
||||
|
||||
items = list(map(lambda r: self._parse_item(r[0].item), result))
|
||||
count = session.query(self._table.item).count()
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
|
||||
13
invokeai/app/util/misc.py
Normal file
13
invokeai/app/util/misc.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
return np.random.randint(0, SEED_MAX)
|
||||
@@ -1,25 +0,0 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def save_thumbnail(
|
||||
image: Image.Image,
|
||||
filename: str,
|
||||
path: str,
|
||||
size: int = 256,
|
||||
) -> str:
|
||||
"""
|
||||
Saves a thumbnail of an image, returning its path.
|
||||
"""
|
||||
base_filename = os.path.splitext(filename)[0]
|
||||
thumbnail_path = os.path.join(path, base_filename + ".webp")
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
return thumbnail_path
|
||||
|
||||
image_copy = image.copy()
|
||||
image_copy.thumbnail(size=(size, size))
|
||||
|
||||
image_copy.save(thumbnail_path, "WEBP")
|
||||
|
||||
return thumbnail_path
|
||||
@@ -1,17 +1,41 @@
|
||||
from re import S
|
||||
import torch
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
def fast_latents_step_callback(
|
||||
sample: torch.Tensor,
|
||||
step: int,
|
||||
steps: int,
|
||||
id: str,
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
context: InvocationContext,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be. Use
|
||||
# that estimate if it is available.
|
||||
if intermediate_state.predicted_original is not None:
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
@@ -21,30 +45,11 @@ def fast_latents_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_id = graph_execution_state.prepared_source_mapping[id]
|
||||
|
||||
invocation = graph_execution_state.execution_graph.get_node(id)
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
invocation_dict=invocation.dict(),
|
||||
source_id=source_id,
|
||||
progress_image={"width": width, "height": height, "dataURL": dataURL},
|
||||
step=step,
|
||||
total_steps=steps,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
total_steps=node["steps"],
|
||||
)
|
||||
|
||||
|
||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
"""
|
||||
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
||||
This adapter grabs the needed data and passes it along to the callback function.
|
||||
"""
|
||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||
progress_state: PipelineIntermediateState = cb_args[0]
|
||||
return fast_latents_step_callback(
|
||||
progress_state.latents, progress_state.step, **kwargs
|
||||
)
|
||||
else:
|
||||
return fast_latents_step_callback(*cb_args, **kwargs)
|
||||
|
||||
15
invokeai/app/util/thumbnails.py
Normal file
15
invokeai/app/util/thumbnails.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_thumbnail_name(image_name: str) -> str:
|
||||
"""Formats given an image name, returns the appropriate thumbnail image name"""
|
||||
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
|
||||
return thumbnail_name
|
||||
|
||||
|
||||
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
|
||||
"""Makes a thumbnail from a PIL Image"""
|
||||
thumbnail = image.copy()
|
||||
thumbnail.thumbnail(size=(size, size))
|
||||
return thumbnail
|
||||
@@ -10,7 +10,7 @@ from .generator import (
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import ModelManager
|
||||
from .model_management import ModelManager, SDModelComponent
|
||||
from .safety_checker import SafetyChecker
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
||||
@@ -96,6 +96,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import invokeai.version
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.image_util import retrieve_metadata
|
||||
|
||||
from .globals import Globals
|
||||
@@ -107,17 +108,21 @@ APP_VERSION = invokeai.version.__version__
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
# diffusers:
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
@@ -189,7 +194,7 @@ class Args(object):
|
||||
print(f"{APP_NAME} {APP_VERSION}")
|
||||
sys.exit(0)
|
||||
|
||||
print("* Initializing, be patient...")
|
||||
logger.info("Initializing, be patient...")
|
||||
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||
Globals.try_patchmatch = switches.patchmatch
|
||||
|
||||
@@ -197,14 +202,13 @@ class Args(object):
|
||||
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
||||
legacyinit = os.path.expanduser("~/.invokeai")
|
||||
if os.path.exists(initfile):
|
||||
print(
|
||||
f">> Initialization file {initfile} found. Loading...",
|
||||
file=sys.stderr,
|
||||
logger.info(
|
||||
f"Initialization file {initfile} found. Loading...",
|
||||
)
|
||||
sysargs.insert(0, f"@{initfile}")
|
||||
elif os.path.exists(legacyinit):
|
||||
print(
|
||||
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||
logger.warning(
|
||||
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||
)
|
||||
sysargs.insert(0, f"@{legacyinit}")
|
||||
Globals.log_tokenization = self._arg_parser.parse_args(
|
||||
@@ -214,7 +218,7 @@ class Args(object):
|
||||
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
||||
return self._arg_switches
|
||||
except Exception as e:
|
||||
print(f"An exception has occurred: {e}")
|
||||
logger.error(f"An exception has occurred: {e}")
|
||||
return None
|
||||
|
||||
def parse_cmd(self, cmd_string):
|
||||
@@ -631,7 +635,7 @@ class Args(object):
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar="SAMPLER_NAME",
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default="k_lms",
|
||||
default="lms",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--log_tokenization",
|
||||
@@ -1154,7 +1158,7 @@ class Args(object):
|
||||
|
||||
|
||||
def format_metadata(**kwargs):
|
||||
print("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||
logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||
return metadata_dumps(kwargs)
|
||||
|
||||
|
||||
@@ -1326,7 +1330,7 @@ def metadata_loads(metadata) -> list:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
print(">> could not read metadata", file=sys.stderr)
|
||||
logger.error("Could not read metadata")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return results
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .args import metadata_from_png
|
||||
from .generator import infill_methods
|
||||
from .globals import Globals, global_cache_dir
|
||||
@@ -36,6 +37,7 @@ from .safety_checker import SafetyChecker
|
||||
from .prompting import get_uc_and_c_and_ec
|
||||
from .prompting.conditioning import log_tokenization
|
||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||
from .stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .util import choose_precision, choose_torch_device
|
||||
|
||||
def fix_func(orig):
|
||||
@@ -140,7 +142,7 @@ class Generate:
|
||||
model=None,
|
||||
conf="configs/models.yaml",
|
||||
embedding_path=None,
|
||||
sampler_name="k_lms",
|
||||
sampler_name="lms",
|
||||
ddim_eta=0.0, # deterministic
|
||||
full_precision=False,
|
||||
precision="auto",
|
||||
@@ -195,12 +197,12 @@ class Generate:
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
# it wasn't actually doing anything. This logic could be reinstated.
|
||||
self.device = torch.device(choose_torch_device())
|
||||
print(f">> Using device_type {self.device.type}")
|
||||
logger.info(f"Using device_type {self.device.type}")
|
||||
if full_precision:
|
||||
if self.precision != "auto":
|
||||
raise ValueError("Remove --full_precision / -F if using --precision")
|
||||
print("Please remove deprecated --full_precision / -F")
|
||||
print("If auto config does not work you can use --precision=float32")
|
||||
logger.warning("Please remove deprecated --full_precision / -F")
|
||||
logger.warning("If auto config does not work you can use --precision=float32")
|
||||
self.precision = "float32"
|
||||
if self.precision == "auto":
|
||||
self.precision = choose_precision(self.device)
|
||||
@@ -208,13 +210,13 @@ class Generate:
|
||||
|
||||
if is_xformers_available():
|
||||
if torch.cuda.is_available() and not Globals.disable_xformers:
|
||||
print(">> xformers memory-efficient attention is available and enabled")
|
||||
logger.info("xformers memory-efficient attention is available and enabled")
|
||||
else:
|
||||
print(
|
||||
">> xformers memory-efficient attention is available but disabled"
|
||||
logger.info(
|
||||
"xformers memory-efficient attention is available but disabled"
|
||||
)
|
||||
else:
|
||||
print(">> xformers not installed")
|
||||
logger.info("xformers not installed")
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_manager = ModelManager(
|
||||
@@ -229,8 +231,8 @@ class Generate:
|
||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||
model = model or fallback
|
||||
if not self.model_manager.valid_model(model):
|
||||
print(
|
||||
f'** "{model}" is not a known model name; falling back to {fallback}.'
|
||||
logger.warning(
|
||||
f'"{model}" is not a known model name; falling back to {fallback}.'
|
||||
)
|
||||
model = None
|
||||
self.model_name = model or fallback
|
||||
@@ -246,10 +248,10 @@ class Generate:
|
||||
|
||||
# load safety checker if requested
|
||||
if safety_checker:
|
||||
print(">> Initializing NSFW checker")
|
||||
logger.info("Initializing NSFW checker")
|
||||
self.safety_checker = SafetyChecker(self.device)
|
||||
else:
|
||||
print(">> NSFW checker is disabled")
|
||||
logger.info("NSFW checker is disabled")
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
@@ -567,7 +569,7 @@ class Generate:
|
||||
self.clear_cuda_cache()
|
||||
|
||||
if catch_interrupts:
|
||||
print("**Interrupted** Partial results will be returned.")
|
||||
logger.warning("Interrupted** Partial results will be returned.")
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
except RuntimeError:
|
||||
@@ -575,11 +577,11 @@ class Generate:
|
||||
self.clear_cuda_cache()
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> Could not generate image.")
|
||||
logger.info("Could not generate image.")
|
||||
|
||||
toc = time.time()
|
||||
print("\n>> Usage stats:")
|
||||
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
|
||||
logger.info("Usage stats:")
|
||||
logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
|
||||
self.print_cuda_stats()
|
||||
return results
|
||||
|
||||
@@ -609,16 +611,16 @@ class Generate:
|
||||
def print_cuda_stats(self):
|
||||
if self._has_cuda():
|
||||
self.gather_cuda_stats()
|
||||
print(
|
||||
">> Max VRAM used for this generation:",
|
||||
"%4.2fG." % (self.max_memory_allocated / 1e9),
|
||||
"Current VRAM utilization:",
|
||||
"%4.2fG" % (self.memory_allocated / 1e9),
|
||||
logger.info(
|
||||
"Max VRAM used for this generation: "+
|
||||
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
|
||||
"Current VRAM utilization: "+
|
||||
"%4.2fG" % (self.memory_allocated / 1e9)
|
||||
)
|
||||
|
||||
print(
|
||||
">> Max VRAM used since script start: ",
|
||||
"%4.2fG" % (self.session_peakmem / 1e9),
|
||||
logger.info(
|
||||
"Max VRAM used since script start: " +
|
||||
"%4.2fG" % (self.session_peakmem / 1e9)
|
||||
)
|
||||
|
||||
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
|
||||
@@ -647,7 +649,7 @@ class Generate:
|
||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
|
||||
prompt = opt.prompt or args.prompt or ""
|
||||
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||
logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||
|
||||
# try to reuse the same filename prefix as the original file.
|
||||
# we take everything up to the first period
|
||||
@@ -696,8 +698,8 @@ class Generate:
|
||||
try:
|
||||
extend_instructions[direction] = int(pixels)
|
||||
except ValueError:
|
||||
print(
|
||||
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||
logger.warning(
|
||||
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||
)
|
||||
|
||||
opt.seed = seed
|
||||
@@ -720,8 +722,8 @@ class Generate:
|
||||
# fetch the metadata from the image
|
||||
generator = self.select_generator(embiggen=True)
|
||||
opt.strength = opt.embiggen_strength or 0.40
|
||||
print(
|
||||
f">> Setting img2img strength to {opt.strength} for happy embiggening"
|
||||
logger.info(
|
||||
f"Setting img2img strength to {opt.strength} for happy embiggening"
|
||||
)
|
||||
generator.generate(
|
||||
prompt,
|
||||
@@ -748,12 +750,12 @@ class Generate:
|
||||
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
||||
|
||||
elif tool is None:
|
||||
print(
|
||||
"* please provide at least one postprocessing option, such as -G or -U"
|
||||
logger.warning(
|
||||
"please provide at least one postprocessing option, such as -G or -U"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
print(f"* postprocessing tool {tool} is not yet supported")
|
||||
logger.warning(f"postprocessing tool {tool} is not yet supported")
|
||||
return None
|
||||
|
||||
def select_generator(
|
||||
@@ -797,8 +799,8 @@ class Generate:
|
||||
image = self._load_img(img)
|
||||
|
||||
if image.width < self.width and image.height < self.height:
|
||||
print(
|
||||
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||
logger.warning(
|
||||
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||
)
|
||||
|
||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||
@@ -809,8 +811,8 @@ class Generate:
|
||||
if (image.width * image.height) > (
|
||||
self.width * self.height
|
||||
) and self.size_matters:
|
||||
print(
|
||||
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||
logger.info(
|
||||
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||
)
|
||||
self.size_matters = False
|
||||
|
||||
@@ -891,11 +893,11 @@ class Generate:
|
||||
try:
|
||||
model_data = cache.get_model(model_name)
|
||||
except Exception as e:
|
||||
print(f"** model {model_name} could not be loaded: {str(e)}")
|
||||
logger.warning(f"model {model_name} could not be loaded: {str(e)}")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
if previous_model_name is None:
|
||||
raise e
|
||||
print("** trying to reload previous model")
|
||||
logger.warning("trying to reload previous model")
|
||||
model_data = cache.get_model(previous_model_name) # load previous
|
||||
if model_data is None:
|
||||
raise e
|
||||
@@ -962,15 +964,15 @@ class Generate:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@@ -984,7 +986,7 @@ class Generate:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
print(">> Face Restoration is disabled.")
|
||||
logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@@ -997,10 +999,10 @@ class Generate:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
@@ -1046,37 +1048,20 @@ class Generate:
|
||||
def _set_scheduler(self):
|
||||
default = self.model.scheduler
|
||||
|
||||
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
||||
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
||||
# provide an alias for compatibility.
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
if self.sampler_name in scheduler_map:
|
||||
sampler_class = scheduler_map[self.sampler_name]
|
||||
if self.sampler_name in SCHEDULER_MAP:
|
||||
sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
|
||||
msg = (
|
||||
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||
)
|
||||
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||
self.sampler = sampler_class.from_config({**self.model.scheduler.config, **sampler_extra_config})
|
||||
else:
|
||||
msg = (
|
||||
f">> Unsupported Sampler: {self.sampler_name} "
|
||||
f" Unsupported Sampler: {self.sampler_name} "+
|
||||
f"Defaulting to {default}"
|
||||
)
|
||||
self.sampler = default
|
||||
|
||||
print(msg)
|
||||
logger.info(msg)
|
||||
|
||||
if not hasattr(self.sampler, "uses_inpainting_model"):
|
||||
# FIXME: terrible kludge!
|
||||
@@ -1085,17 +1070,17 @@ class Generate:
|
||||
def _load_img(self, img) -> Image:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
print(f">> using provided input image of size {image.width}x{image.height}")
|
||||
logger.info(f"using provided input image of size {image.width}x{image.height}")
|
||||
elif isinstance(img, str):
|
||||
assert os.path.exists(img), f">> {img}: File not found"
|
||||
assert os.path.exists(img), f"{img}: File not found"
|
||||
|
||||
image = Image.open(img)
|
||||
print(
|
||||
f">> loaded input image of size {image.width}x{image.height} from {img}"
|
||||
logger.info(
|
||||
f"loaded input image of size {image.width}x{image.height} from {img}"
|
||||
)
|
||||
else:
|
||||
image = Image.open(img)
|
||||
print(f">> loaded input image of size {image.width}x{image.height}")
|
||||
logger.info(f"loaded input image of size {image.width}x{image.height}")
|
||||
image = ImageOps.exif_transpose(image)
|
||||
return image
|
||||
|
||||
@@ -1183,14 +1168,14 @@ class Generate:
|
||||
|
||||
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
||||
if not mask:
|
||||
print(
|
||||
">> Initial image has transparent areas. Will inpaint in these regions."
|
||||
logger.info(
|
||||
"Initial image has transparent areas. Will inpaint in these regions."
|
||||
)
|
||||
if (not force_outpaint) and self._check_for_erasure(image):
|
||||
print(
|
||||
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
|
||||
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
|
||||
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
|
||||
if (not force_outpaint) and self._check_for_erasure(image):
|
||||
logger.info(
|
||||
"Colors underneath the transparent region seem to have been erased.\n" +
|
||||
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
|
||||
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
|
||||
)
|
||||
|
||||
def _squeeze_image(self, image):
|
||||
@@ -1201,11 +1186,11 @@ class Generate:
|
||||
|
||||
def _fit_image(self, image, max_dimensions):
|
||||
w, h = max_dimensions
|
||||
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
|
||||
logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
|
||||
# note that InitImageResizer does the multiple of 64 truncation internally
|
||||
image = InitImageResizer(image).resize(width=w, height=h)
|
||||
print(
|
||||
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||
logger.info(
|
||||
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||
)
|
||||
return image
|
||||
|
||||
@@ -1216,8 +1201,8 @@ class Generate:
|
||||
) # resize to integer multiple of 64
|
||||
if h != height or w != width:
|
||||
if log:
|
||||
print(
|
||||
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||
logger.info(
|
||||
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||
)
|
||||
height = h
|
||||
width = w
|
||||
|
||||
@@ -25,11 +25,13 @@ from typing import Callable, List, Iterator, Optional, Type
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@@ -70,19 +72,6 @@ class InvokeAIGeneratorOutput:
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
@@ -174,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@@ -225,10 +220,10 @@ class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
@@ -372,7 +367,7 @@ class Generator:
|
||||
try:
|
||||
x_T = self.get_noise(width, height)
|
||||
except:
|
||||
print("** An error occurred while getting initial noise **")
|
||||
logger.error("An error occurred while getting initial noise")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||
@@ -607,7 +602,7 @@ class Generator:
|
||||
image = self.sample_to_image(sample)
|
||||
dirname = os.path.dirname(filepath) or "."
|
||||
if not os.path.exists(dirname):
|
||||
print(f"** creating directory {dirname}")
|
||||
logger.info(f"creating directory {dirname}")
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath, "PNG")
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -72,22 +73,22 @@ class Embiggen(Generator):
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
print(
|
||||
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
print(
|
||||
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
print(
|
||||
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
@@ -97,8 +98,8 @@ class Embiggen(Generator):
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
print(
|
||||
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
@@ -121,8 +122,8 @@ class Embiggen(Generator):
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
@@ -312,10 +313,10 @@ class Embiggen(Generator):
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
else:
|
||||
print(
|
||||
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
logger.info(
|
||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
)
|
||||
|
||||
emb_tile_store = []
|
||||
@@ -361,11 +362,11 @@ class Embiggen(Generator):
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
print(
|
||||
logger.debug(
|
||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||
)
|
||||
else:
|
||||
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||
@@ -547,8 +548,8 @@ class Embiggen(Generator):
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
print(
|
||||
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
logger.error(
|
||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
)
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
|
||||
@@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -59,7 +60,7 @@ class Inpaint(Img2Img):
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
@@ -75,18 +76,18 @@ class Inpaint(Img2Img):
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: int = None
|
||||
) -> Image:
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a, *tile_size).copy()
|
||||
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
@@ -127,7 +128,9 @@ class Inpaint(Img2Img):
|
||||
|
||||
return si
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
def mask_edge(
|
||||
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||
) -> Image.Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
@@ -206,15 +209,15 @@ class Inpaint(Img2Img):
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False,
|
||||
@@ -222,7 +225,7 @@ class Inpaint(Img2Img):
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -239,7 +242,7 @@ class Inpaint(Img2Img):
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
if isinstance(init_image, Image.Image):
|
||||
self.pil_image = init_image.copy()
|
||||
|
||||
# Do infill
|
||||
@@ -250,8 +253,8 @@ class Inpaint(Img2Img):
|
||||
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||
)
|
||||
elif infill_method == "solid":
|
||||
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
|
||||
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non-supported infill type {infill_method}", infill_method
|
||||
@@ -269,7 +272,7 @@ class Inpaint(Img2Img):
|
||||
# Create init tensor
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
if isinstance(mask_image, Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(
|
||||
mask_image,
|
||||
|
||||
@@ -14,6 +14,8 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
@@ -77,8 +79,8 @@ class Txt2Img2Img(Generator):
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
print(
|
||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
logger.info(
|
||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
|
||||
@@ -5,10 +5,9 @@ wraps the actual patchmatch object. It respects the global
|
||||
be suppressed or deferred
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
@@ -28,12 +27,12 @@ class PatchMatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
print(">> Patchmatch initialized")
|
||||
logger.info("Patchmatch initialized")
|
||||
else:
|
||||
print(">> Patchmatch not loaded (nonfatal)")
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
self.patch_match = pm
|
||||
else:
|
||||
print(">> Patchmatch loading disabled")
|
||||
logger.info("Patchmatch loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -41,7 +41,7 @@ class PngWriter:
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text("Dream", dream_prompt)
|
||||
if metadata:
|
||||
info.add_text("invokeai", json.dumps(metadata))
|
||||
info.add_text("sd-metadata", json.dumps(metadata))
|
||||
image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
|
||||
return path
|
||||
|
||||
|
||||
@@ -30,9 +30,9 @@ work fine.
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from torchvision import transforms
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
@@ -83,7 +83,7 @@ class Txt2Mask(object):
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", refined=False):
|
||||
print(">> Initializing clipseg model for text to mask inference")
|
||||
logger.info("Initializing clipseg model for text to mask inference")
|
||||
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
@@ -101,18 +101,6 @@ class Txt2Mask(object):
|
||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||
pixels indicate where the object is inferred to be.
|
||||
"""
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
transforms.Resize(
|
||||
(CLIPSEG_SIZE, CLIPSEG_SIZE)
|
||||
), # must be multiple of 64...
|
||||
]
|
||||
)
|
||||
|
||||
if type(image) is str:
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from .convert_ckpt_to_diffusers import (
|
||||
convert_ckpt_to_diffusers,
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from .model_manager import ModelManager
|
||||
from .model_manager import ModelManager,SDModelComponent
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import Union
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
@@ -46,6 +47,7 @@ from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -372,9 +374,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(" | Extracting EMA weights (usually better for inference)")
|
||||
logger.debug("Extracting EMA weights (usually better for inference)")
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
@@ -392,8 +394,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
key
|
||||
)
|
||||
else:
|
||||
print(
|
||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
logger.debug(
|
||||
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
@@ -1115,7 +1117,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print(" | global_step key not found in model")
|
||||
logger.debug("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
# sometimes there is a state_dict key and sometimes not
|
||||
@@ -1208,6 +1210,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == 'unipc':
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
@@ -1229,15 +1233,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
# If a replacement VAE path was specified, we'll incorporate that into
|
||||
# the checkpoint model and then convert it
|
||||
if vae_path:
|
||||
print(f" | Converting VAE {vae_path}")
|
||||
logger.debug(f"Converting VAE {vae_path}")
|
||||
replace_checkpoint_vae(checkpoint,vae_path)
|
||||
# otherwise we use the original VAE, provided that
|
||||
# an externally loaded diffusers VAE was not passed
|
||||
elif not vae:
|
||||
print(" | Using checkpoint model's original VAE")
|
||||
logger.debug("Using checkpoint model's original VAE")
|
||||
|
||||
if vae:
|
||||
print(" | Using replacement diffusers VAE")
|
||||
logger.debug("Using replacement diffusers VAE")
|
||||
else: # convert the original or replacement VAE
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config, image_size=image_size
|
||||
|
||||
@@ -18,18 +18,19 @@ import warnings
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable
|
||||
from typing import Any, Optional, Union, Callable, types
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import transformers
|
||||
import invokeai.backend.util.logging as logger
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
SchedulerMixin,
|
||||
logging as dlogging,
|
||||
)
|
||||
)
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
@@ -67,7 +68,7 @@ class SDModelComponent(Enum):
|
||||
scheduler="scheduler"
|
||||
safety_checker="safety_checker"
|
||||
feature_extractor="feature_extractor"
|
||||
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
|
||||
class ModelManager(object):
|
||||
@@ -75,6 +76,8 @@ class ModelManager(object):
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
"""
|
||||
|
||||
logger: types.ModuleType = logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf | Path,
|
||||
@@ -83,6 +86,7 @@ class ModelManager(object):
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path = None,
|
||||
logger: types.ModuleType = logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file or
|
||||
@@ -104,6 +108,7 @@ class ModelManager(object):
|
||||
self.current_model = None
|
||||
self.sequential_offload = sequential_offload
|
||||
self.embedding_path = embedding_path
|
||||
self.logger = logger
|
||||
|
||||
def valid_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
@@ -132,8 +137,8 @@ class ModelManager(object):
|
||||
)
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return self.current_model
|
||||
|
||||
@@ -144,7 +149,7 @@ class ModelManager(object):
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]["model"]
|
||||
print(f">> Retrieving model {model_name} from system RAM cache")
|
||||
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
|
||||
requested_model.ready()
|
||||
width = self.models[model_name]["width"]
|
||||
height = self.models[model_name]["height"]
|
||||
@@ -177,7 +182,7 @@ class ModelManager(object):
|
||||
vae from the model currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.vae)
|
||||
|
||||
|
||||
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTokenizer. If no
|
||||
@@ -185,12 +190,12 @@ class ModelManager(object):
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
||||
|
||||
|
||||
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
||||
name is provided, return the UNet from the model
|
||||
currently in the GPU.
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.unet)
|
||||
|
||||
@@ -217,7 +222,7 @@ class ModelManager(object):
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
||||
|
||||
|
||||
def _get_sub_model(
|
||||
self,
|
||||
model_name: str=None,
|
||||
@@ -379,7 +384,7 @@ class ModelManager(object):
|
||||
"""
|
||||
omega = self.config
|
||||
if model_name not in omega:
|
||||
print(f"** Unknown model {model_name}")
|
||||
self.logger.error(f"Unknown model {model_name}")
|
||||
return
|
||||
# save these for use in deletion later
|
||||
conf = omega[model_name]
|
||||
@@ -392,13 +397,13 @@ class ModelManager(object):
|
||||
self.stack.remove(model_name)
|
||||
if delete_files:
|
||||
if weights:
|
||||
print(f"** Deleting file {weights}")
|
||||
self.logger.info(f"Deleting file {weights}")
|
||||
Path(weights).unlink(missing_ok=True)
|
||||
elif path:
|
||||
print(f"** Deleting directory {path}")
|
||||
self.logger.info(f"Deleting directory {path}")
|
||||
rmtree(path, ignore_errors=True)
|
||||
elif repo_id:
|
||||
print(f"** Deleting the cached model directory for {repo_id}")
|
||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
||||
self._delete_model_from_cache(repo_id)
|
||||
|
||||
def add_model(
|
||||
@@ -439,7 +444,7 @@ class ModelManager(object):
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
print(
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return
|
||||
@@ -457,7 +462,7 @@ class ModelManager(object):
|
||||
model_format = mconfig.get("format", "ckpt")
|
||||
if model_format == "ckpt":
|
||||
weights = mconfig.weights
|
||||
print(f">> Loading {model_name} from {weights}")
|
||||
self.logger.info(f"Loading {model_name} from {weights}")
|
||||
model, width, height, model_hash = self._load_ckpt_model(
|
||||
model_name, mconfig
|
||||
)
|
||||
@@ -473,13 +478,15 @@ class ModelManager(object):
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||
if self._has_cuda():
|
||||
print(
|
||||
">> Max VRAM used to load the model:",
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
"\n>> Current VRAM usage:"
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
self.logger.info(
|
||||
"Max VRAM used to load the model: "+
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||
)
|
||||
self.logger.info(
|
||||
"Current VRAM usage: "+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
@@ -487,11 +494,11 @@ class ModelManager(object):
|
||||
name_or_path = self.model_name_or_path(mconfig)
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
print(f">> Loading diffusers model from {name_or_path}")
|
||||
self.logger.info(f"Loading diffusers model from {name_or_path}")
|
||||
if using_fp16:
|
||||
print(" | Using faster float16 precision")
|
||||
self.logger.debug("Using faster float16 precision")
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
@@ -523,8 +530,8 @@ class ModelManager(object):
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"** An unexpected error occurred while downloading the model: {e})"
|
||||
self.logger.error(
|
||||
f"An unexpected error occurred while downloading the model: {e})"
|
||||
)
|
||||
if pipeline:
|
||||
break
|
||||
@@ -542,7 +549,7 @@ class ModelManager(object):
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
self.logger.debug(f"Default image dimensions = {width} x {height}")
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
@@ -559,14 +566,14 @@ class ModelManager(object):
|
||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||
|
||||
# Convert to diffusers and return a diffusers pipeline
|
||||
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
try:
|
||||
if self.list_models()[self.current_model]["status"] == "active":
|
||||
self.offload_model(self.current_model)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vae_path = None
|
||||
@@ -624,7 +631,7 @@ class ModelManager(object):
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
print(f">> Offloading {model_name} to CPU")
|
||||
self.logger.info(f"Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.current_model = None
|
||||
@@ -640,30 +647,26 @@ class ModelManager(object):
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
print(f" | Scanning Model: {model_name}")
|
||||
self.logger.debug(f"Scanning Model: {model_name}")
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
|
||||
print(
|
||||
"### WARNING: The model you are trying to load seems to be infected."
|
||||
)
|
||||
print("### For your safety, InvokeAI will not load this model.")
|
||||
print("### Please use checkpoints from trusted sources.")
|
||||
print("### Exiting InvokeAI")
|
||||
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||
self.logger.critical("The model you are trying to load seems to be infected.")
|
||||
self.logger.critical("For your safety, InvokeAI will not load this model.")
|
||||
self.logger.critical("Please use checkpoints from trusted sources.")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(
|
||||
"\n### WARNING: InvokeAI was unable to scan the model you are using."
|
||||
)
|
||||
self.logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||
model_safe_check_fail = ask_user(
|
||||
"Do you want to to continue loading the model?", ["y", "n"]
|
||||
)
|
||||
if model_safe_check_fail.lower() != "y":
|
||||
print("### Exiting InvokeAI")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(" | Model scanned ok")
|
||||
self.logger.debug("Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
@@ -780,26 +783,24 @@ class ModelManager(object):
|
||||
model_path: Path = None
|
||||
thing = path_url_or_repo # to save typing
|
||||
|
||||
print(f">> Probing {thing} for import")
|
||||
self.logger.info(f"Probing {thing} for import")
|
||||
|
||||
if thing.startswith(("http:", "https:", "ftp:")):
|
||||
print(f" | {thing} appears to be a URL")
|
||||
self.logger.info(f"{thing} appears to be a URL")
|
||||
model_path = self._resolve_path(
|
||||
thing, "models/ldm/stable-diffusion-v1"
|
||||
) # _resolve_path does a download if needed
|
||||
|
||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||
print(
|
||||
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
||||
)
|
||||
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
||||
return
|
||||
else:
|
||||
print(f" | {thing} appears to be a checkpoint file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||
|
||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers file on disk")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing,
|
||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||
@@ -810,34 +811,30 @@ class ModelManager(object):
|
||||
|
||||
elif Path(thing).is_dir():
|
||||
if (Path(thing) / "model_index.json").exists():
|
||||
print(f" | {thing} appears to be a diffusers model.")
|
||||
self.logger.debug(f"{thing} appears to be a diffusers model.")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" |{thing} appears to be a directory. Will scan for models to import"
|
||||
)
|
||||
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||
Path(thing).rglob("*.safetensors")
|
||||
):
|
||||
if model_name := self.heuristic_import(
|
||||
str(m), commit_to_conf=commit_to_conf
|
||||
):
|
||||
print(f" >> {model_name} successfully imported")
|
||||
self.logger.info(f"{model_name} successfully imported")
|
||||
return model_name
|
||||
|
||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
||||
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
||||
model_name = self.import_diffuser_model(
|
||||
thing, commit_to_conf=commit_to_conf
|
||||
)
|
||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||
return model_name
|
||||
else:
|
||||
print(
|
||||
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
|
||||
)
|
||||
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
||||
|
||||
# Model_path is set in the event of a legacy checkpoint file.
|
||||
# If not set, we're all done
|
||||
@@ -845,7 +842,7 @@ class ModelManager(object):
|
||||
return
|
||||
|
||||
if model_path.stem in self.config: # already imported
|
||||
print(" | Already imported. Skipping")
|
||||
self.logger.debug("Already imported. Skipping")
|
||||
return model_path.stem
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
@@ -861,39 +858,39 @@ class ModelManager(object):
|
||||
# look for a like-named .yaml file in same directory
|
||||
if model_path.with_suffix(".yaml").exists():
|
||||
model_config_file = model_path.with_suffix(".yaml")
|
||||
print(f" | Using config file {model_config_file.name}")
|
||||
self.logger.debug(f"Using config file {model_config_file.name}")
|
||||
|
||||
else:
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
if model_type == SDLegacyType.V1:
|
||||
print(" | SD-v1 model detected")
|
||||
self.logger.debug("SD-v1 model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
self.logger.debug("SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
print(" | SD-v2-v model detected")
|
||||
self.logger.debug("SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
print(" | SD-v2-e model detected")
|
||||
self.logger.debug("SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2:
|
||||
print(
|
||||
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
self.logger.warning(
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
else:
|
||||
print(
|
||||
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
self.logger.warning(
|
||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -909,7 +906,7 @@ class ModelManager(object):
|
||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||
print(f" | Using VAE file {vae_path.name}")
|
||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
|
||||
diffuser_path = Path(
|
||||
@@ -955,14 +952,14 @@ class ModelManager(object):
|
||||
from . import convert_ckpt_to_diffusers
|
||||
|
||||
if diffusers_path.exists():
|
||||
print(
|
||||
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
self.logger.error(
|
||||
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
)
|
||||
return
|
||||
|
||||
model_name = model_name or diffusers_path.name
|
||||
model_description = model_description or f"Converted version of {model_name}"
|
||||
print(f" | Converting {model_name} to diffusers (30-60s)")
|
||||
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
@@ -979,10 +976,10 @@ class ModelManager(object):
|
||||
vae_path=vae_path,
|
||||
scan_needed=scan_needed,
|
||||
)
|
||||
print(
|
||||
f" | Success. Converted model is now located at {str(diffusers_path)}"
|
||||
self.logger.debug(
|
||||
f"Success. Converted model is now located at {str(diffusers_path)}"
|
||||
)
|
||||
print(f" | Writing new config file entry for {model_name}")
|
||||
self.logger.debug(f"Writing new config file entry for {model_name}")
|
||||
new_config = dict(
|
||||
path=str(diffusers_path),
|
||||
description=model_description,
|
||||
@@ -993,17 +990,17 @@ class ModelManager(object):
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
print(" | Conversion succeeded")
|
||||
self.logger.debug("Conversion succeeded")
|
||||
except Exception as e:
|
||||
print(f"** Conversion failed: {str(e)}")
|
||||
print(
|
||||
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
self.logger.warning(f"Conversion failed: {str(e)}")
|
||||
self.logger.warning(
|
||||
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
def search_models(self, search_folder):
|
||||
print(f">> Finding Models In: {search_folder}")
|
||||
self.logger.info(f"Finding Models In: {search_folder}")
|
||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||
|
||||
@@ -1027,8 +1024,8 @@ class ModelManager(object):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(
|
||||
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
self.logger.info(
|
||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
)
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
@@ -1036,8 +1033,8 @@ class ModelManager(object):
|
||||
|
||||
def print_vram_usage(self) -> None:
|
||||
if self._has_cuda:
|
||||
print(
|
||||
">> Current VRAM usage: ",
|
||||
self.logger.info(
|
||||
"Current VRAM usage:"+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
@@ -1126,10 +1123,10 @@ class ModelManager(object):
|
||||
dest = hub / model.stem
|
||||
if dest.exists() and not source.exists():
|
||||
continue
|
||||
print(f"** {source} => {dest}")
|
||||
cls.logger.info(f"{source} => {dest}")
|
||||
if source.exists():
|
||||
if dest.is_symlink():
|
||||
print(f"** Found symlink at {dest.name}. Not migrating.")
|
||||
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
|
||||
elif dest.exists():
|
||||
if source.is_dir():
|
||||
rmtree(source)
|
||||
@@ -1146,7 +1143,7 @@ class ModelManager(object):
|
||||
]
|
||||
for d in empty:
|
||||
os.rmdir(d)
|
||||
print("** Migration is done. Continuing...")
|
||||
cls.logger.info("Migration is done. Continuing...")
|
||||
|
||||
def _resolve_path(
|
||||
self, source: Union[str, Path], dest_directory: str
|
||||
@@ -1189,15 +1186,15 @@ class ModelManager(object):
|
||||
|
||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||
if self.embedding_path is not None:
|
||||
print(f">> Loading embeddings from {self.embedding_path}")
|
||||
self.logger.info(f"Loading embeddings from {self.embedding_path}")
|
||||
for root, _, files in os.walk(self.embedding_path):
|
||||
for name in files:
|
||||
ti_path = os.path.join(root, name)
|
||||
model.textual_inversion_manager.load_textual_inversion(
|
||||
ti_path, defer_injecting_tokens=True
|
||||
)
|
||||
print(
|
||||
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
self.logger.info(
|
||||
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
)
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
@@ -1219,7 +1216,7 @@ class ModelManager(object):
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
print(" | Calculating sha256 hash of model files")
|
||||
self.logger.debug("Calculating sha256 hash of model files")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
count = 0
|
||||
@@ -1231,7 +1228,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
@@ -1249,13 +1246,13 @@ class ModelManager(object):
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
print(" | Calculating sha256 hash of weights file")
|
||||
self.logger.debug("Calculating sha256 hash of weights file")
|
||||
tic = time.time()
|
||||
sha = hashlib.sha256()
|
||||
sha.update(data)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
||||
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
@@ -1276,12 +1273,12 @@ class ModelManager(object):
|
||||
local_files_only=not Globals.internet_available,
|
||||
)
|
||||
|
||||
print(f" | Loading diffusers VAE from {name_or_path}")
|
||||
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||
if using_fp16:
|
||||
vae_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
print(" | Using more accurate float32 precision")
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
fp_args_list = [{}]
|
||||
|
||||
vae = None
|
||||
@@ -1305,12 +1302,12 @@ class ModelManager(object):
|
||||
break
|
||||
|
||||
if not vae and deferred_error:
|
||||
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||
|
||||
return vae
|
||||
|
||||
@staticmethod
|
||||
def _delete_model_from_cache(repo_id):
|
||||
@classmethod
|
||||
def _delete_model_from_cache(cls,repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||
|
||||
# I'm sure there is a way to do this with comprehensions
|
||||
@@ -1321,8 +1318,8 @@ class ModelManager(object):
|
||||
for revision in repo.revisions:
|
||||
hashes_to_delete.add(revision.commit_hash)
|
||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||
print(
|
||||
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
cls.logger.warning(
|
||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||
)
|
||||
strategy.execute()
|
||||
|
||||
|
||||
@@ -16,66 +16,58 @@ from compel.prompt_parser import (
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(
|
||||
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
|
||||
):
|
||||
def get_uc_and_c_and_ec(prompt_string,
|
||||
model: InvokeAIDiffuserComponent,
|
||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_string
|
||||
)
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False
|
||||
)
|
||||
compel = Compel(tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
(
|
||||
positive_prompt_string,
|
||||
negative_prompt_string,
|
||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
||||
if log_tokens or getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
tokens_count = get_max_token_count(tokenizer, positive_prompt)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get(
|
||||
'cross_attention_control', None))
|
||||
return uc, c, ec
|
||||
|
||||
|
||||
def get_prompt_structure(
|
||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
@@ -86,18 +78,17 @@ def get_prompt_structure(
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
positive_prompt: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
return positive_prompt, negative_prompt
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
@@ -162,8 +153,8 @@ def log_tokenization(
|
||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||
tokenizer,
|
||||
):
|
||||
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
|
||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||
log_tokenization_for_prompt_object(
|
||||
@@ -237,29 +228,28 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
print(f"{tokenized}\x1b[0m")
|
||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
logger.debug(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
print(f"{discarded}\x1b[0m")
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(
|
||||
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
|
||||
)
|
||||
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
"""
|
||||
@@ -295,8 +285,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
logger.warning(
|
||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
)
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Restoration:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
@@ -8,17 +10,17 @@ class Restoration:
|
||||
# Load GFPGAN
|
||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||
if gfpgan.gfpgan_model_exists:
|
||||
print(">> GFPGAN Initialized")
|
||||
logger.info("GFPGAN Initialized")
|
||||
else:
|
||||
print(">> GFPGAN Disabled")
|
||||
logger.info("GFPGAN Disabled")
|
||||
gfpgan = None
|
||||
|
||||
# Load CodeFormer
|
||||
codeformer = self.load_codeformer()
|
||||
if codeformer.codeformer_model_exists:
|
||||
print(">> CodeFormer Initialized")
|
||||
logger.info("CodeFormer Initialized")
|
||||
else:
|
||||
print(">> CodeFormer Disabled")
|
||||
logger.info("CodeFormer Disabled")
|
||||
codeformer = None
|
||||
|
||||
return gfpgan, codeformer
|
||||
@@ -39,5 +41,5 @@ class Restoration:
|
||||
from .realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN(esrgan_bg_tile)
|
||||
print(">> ESRGAN Initialized")
|
||||
logger.info("ESRGAN Initialized")
|
||||
return esrgan
|
||||
|
||||
@@ -5,6 +5,7 @@ import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..globals import Globals
|
||||
|
||||
pretrained_model_url = (
|
||||
@@ -23,12 +24,12 @@ class CodeFormerRestoration:
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.codeformer_model_exists:
|
||||
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
sys.path.append(os.path.abspath(codeformer_dir))
|
||||
|
||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||
if seed is not None:
|
||||
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
|
||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
@@ -97,7 +98,7 @@ class CodeFormerRestoration:
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except RuntimeError as error:
|
||||
print(f"\tFailed inference for CodeFormer: {error}.")
|
||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype("uint8")
|
||||
|
||||
@@ -6,9 +6,9 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class GFPGAN:
|
||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||
if not os.path.isabs(gfpgan_model_path):
|
||||
@@ -19,7 +19,7 @@ class GFPGAN:
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
if not self.gfpgan_model_exists:
|
||||
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||
return None
|
||||
|
||||
def model_exists(self):
|
||||
@@ -27,7 +27,7 @@ class GFPGAN:
|
||||
|
||||
def process(self, image, strength: float, seed: str = None):
|
||||
if seed is not None:
|
||||
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
|
||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
@@ -47,14 +47,14 @@ class GFPGAN:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
print(">> Error loading GFPGAN:", file=sys.stderr)
|
||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
os.chdir(cwd)
|
||||
|
||||
if self.gfpgan is None:
|
||||
print(f">> WARNING: GFPGAN not initialized.")
|
||||
print(
|
||||
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||
logger.warning("WARNING: GFPGAN not initialized.")
|
||||
logger.warning(
|
||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||
)
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Outcrop(object):
|
||||
def __init__(
|
||||
@@ -82,7 +82,7 @@ class Outcrop(object):
|
||||
pixels = extents[direction]
|
||||
# round pixels up to the nearest 64
|
||||
pixels = math.ceil(pixels / 64) * 64
|
||||
print(f">> extending image {direction}ward by {pixels} pixels")
|
||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
||||
image = self._rotate(image, direction)
|
||||
image = self._extend(image, pixels)
|
||||
image = self._rotate(image, direction, reverse=True)
|
||||
|
||||
@@ -6,18 +6,13 @@ import torch
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class ESRGAN:
|
||||
def __init__(self, bg_tile_size=400) -> None:
|
||||
self.bg_tile_size = bg_tile_size
|
||||
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
@@ -74,16 +69,16 @@ class ESRGAN:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
|
||||
logger.error("Error loading Real-ESRGAN:")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if upsampler_scale == 0:
|
||||
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||
return image
|
||||
|
||||
if seed is not None:
|
||||
print(
|
||||
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||
logger.info(
|
||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||
)
|
||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -14,6 +14,7 @@ from PIL import Image, ImageFilter
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .globals import global_cache_dir
|
||||
from .util import CPU_DEVICE
|
||||
|
||||
@@ -40,8 +41,8 @@ class SafetyChecker(object):
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"** An error was encountered while installing the safety checker:"
|
||||
logger.error(
|
||||
"An error was encountered while installing the safety checker:"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
|
||||
@@ -65,8 +66,8 @@ class SafetyChecker(object):
|
||||
)
|
||||
self.safety_checker.to(CPU_DEVICE) # offload
|
||||
if has_nsfw_concept[0]:
|
||||
print(
|
||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||
logger.warning(
|
||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
||||
)
|
||||
return self.blur(image)
|
||||
else:
|
||||
|
||||
@@ -17,6 +17,7 @@ from huggingface_hub import (
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
@@ -57,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
else:
|
||||
elif Globals.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
@@ -66,13 +67,15 @@ class HuggingFaceConceptsLibrary(object):
|
||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
except Exception as e:
|
||||
print(
|
||||
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
logger.warning(
|
||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
)
|
||||
print(
|
||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
logger.warning(
|
||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
@@ -81,7 +84,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
be downloaded.
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
print(
|
||||
logger.warning(
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
@@ -219,7 +222,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
if chunk == 0:
|
||||
bytes += total
|
||||
|
||||
print(f">> Downloading {repo_id}...", end="")
|
||||
logger.info(f"Downloading {repo_id}...", end="")
|
||||
try:
|
||||
for file in (
|
||||
"README.md",
|
||||
@@ -233,22 +236,22 @@ class HuggingFaceConceptsLibrary(object):
|
||||
)
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
logger.warning(
|
||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
print(
|
||||
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
logger.error(
|
||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
print("...{:.2f}Kb".format(bytes / 1024))
|
||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||
return succeeded
|
||||
|
||||
def _concept_id(self, concept_name: str) -> str:
|
||||
|
||||
@@ -445,8 +445,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
@property
|
||||
def _submodels(self) -> Sequence[torch.nn.Module]:
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
values = [getattr(self, name) for name in module_names.keys()]
|
||||
return [m for m in values if isinstance(m, torch.nn.Module)]
|
||||
submodels = []
|
||||
for name in module_names.keys():
|
||||
if hasattr(self, name):
|
||||
value = getattr(self, name)
|
||||
else:
|
||||
value = getattr(self.config, name)
|
||||
if isinstance(value, torch.nn.Module):
|
||||
submodels.append(value)
|
||||
return submodels
|
||||
|
||||
def image_from_embeddings(
|
||||
self,
|
||||
@@ -502,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
if timesteps is None:
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps, device=self._model_group.device_for(self.unet)
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||
@@ -538,13 +548,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
step=-1,
|
||||
timestep=self.scheduler.num_train_timesteps,
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
@@ -718,12 +729,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
noise: torch.Tensor,
|
||||
run_id=None,
|
||||
callback=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps,
|
||||
strength,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||
@@ -749,13 +756,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def get_img2img_timesteps(
|
||||
self, num_inference_steps: int, strength: float, device
|
||||
self, num_inference_steps: int, strength: float, device=None
|
||||
) -> (torch.Tensor, int):
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
num_inference_steps, strength, device=scheduler_device
|
||||
)
|
||||
# Workaround for low strength resulting in zero timesteps.
|
||||
# TODO: submit upstream fix for zero-step img2img
|
||||
@@ -789,9 +802,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
)
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
@@ -915,7 +926,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.in_channels
|
||||
return self.unet.config.in_channels
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||
|
||||
@@ -10,13 +10,13 @@ import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...util import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
@@ -188,7 +188,7 @@ class Context:
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
@@ -209,7 +209,7 @@ class InvokeAICrossAttentionMixin:
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||
`module` is the current Attention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
@@ -345,16 +345,15 @@ class InvokeAICrossAttentionMixin:
|
||||
def restore_default_cross_attention(
|
||||
model,
|
||||
is_running_diffusers: bool,
|
||||
restore_attention_processor: Optional[AttnProcessor] = None,
|
||||
restore_attention_processor: Optional[AttentionProcessor] = None,
|
||||
):
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
|
||||
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def override_cross_attention(model, context: Context, is_running_diffusers=False):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@@ -373,47 +372,29 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next(
|
||||
(
|
||||
p.slice_size
|
||||
for p in old_attn_processors.values()
|
||||
if type(p) is SlicedAttnProcessor
|
||||
),
|
||||
default_slice_size,
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
return old_attn_processors
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
context.register_cross_attention_modules(model)
|
||||
inject_attention_function(model, context)
|
||||
return None
|
||||
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
def get_cross_attention_modules(
|
||||
model, which: CrossAttentionType
|
||||
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
from ldm.modules.attention import CrossAttention # avoid circular import
|
||||
|
||||
cross_attention_class: type = (
|
||||
InvokeAIDiffusersCrossAttention
|
||||
if isinstance(model, UNet2DConditionModel)
|
||||
else CrossAttention
|
||||
)
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [
|
||||
@@ -425,13 +406,13 @@ def get_cross_attention_modules(
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
print(
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||
+ f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
||||
+ f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
||||
+ f"work properly until it is fixed."
|
||||
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
||||
+ "work properly until it is fixed."
|
||||
)
|
||||
return attention_module_tuples
|
||||
|
||||
@@ -550,7 +531,7 @@ def get_mem_free_total(device):
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(
|
||||
diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
|
||||
diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
|
||||
):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -572,8 +553,8 @@ class InvokeAIDiffusersCrossAttention(
|
||||
"""
|
||||
# base implementation
|
||||
|
||||
class CrossAttnProcessor:
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
class AttnProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
@@ -601,9 +582,9 @@ class CrossAttnProcessor:
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from diffusers.models.cross_attention import (
|
||||
CrossAttention,
|
||||
CrossAttnProcessor,
|
||||
from diffusers.models.attention_processor import (
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
SlicedAttnProcessor,
|
||||
)
|
||||
|
||||
@@ -653,7 +634,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: CrossAttention,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
|
||||
@@ -5,9 +5,11 @@ from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from .cross_attention_control import (
|
||||
@@ -16,8 +18,8 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
override_cross_attention,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
@@ -78,30 +80,41 @@ class InvokeAIDiffuserComponent:
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = Globals.sequential_guidance
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
|
||||
cls,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int
|
||||
):
|
||||
do_swap = (
|
||||
extra_conditioning_info is not None
|
||||
and extra_conditioning_info.wants_cross_attention_control
|
||||
)
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(
|
||||
extra_conditioning_info, step_count=step_count
|
||||
)
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (
|
||||
extra_conditioning_info.wants_cross_attention_control
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
cross_attention_control_context,
|
||||
)
|
||||
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if old_attn_processor is not None:
|
||||
self.restore_default_cross_attention(old_attn_processor)
|
||||
if old_attn_processors is not None:
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
def override_cross_attention(
|
||||
self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
) -> Dict[str, AttnProcessor]:
|
||||
) -> Dict[str, AttentionProcessor]:
|
||||
"""
|
||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
the previous attention processor is returned so that the caller can restore it later.
|
||||
@@ -118,7 +131,7 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
def restore_default_cross_attention(
|
||||
self, restore_attention_processor: Optional["AttnProcessor"] = None
|
||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||
):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
@@ -262,7 +275,7 @@ class InvokeAIDiffuserComponent:
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(
|
||||
f"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||
)
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
return percent_through
|
||||
@@ -466,10 +479,14 @@ class InvokeAIDiffuserComponent:
|
||||
outside = torch.count_nonzero(
|
||||
(latents < -current_threshold) | (latents > current_threshold)
|
||||
)
|
||||
print(
|
||||
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
logger.info(
|
||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
||||
)
|
||||
logger.debug(
|
||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||
)
|
||||
logger.debug(
|
||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
)
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
@@ -496,9 +513,11 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
if self.debug_thresholding:
|
||||
print(
|
||||
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
logger.debug(
|
||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
||||
)
|
||||
logger.debug(
|
||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
||||
)
|
||||
|
||||
return latents
|
||||
@@ -599,7 +618,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
# below is fugly omg
|
||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||
chunk_count = ceil(len(conditionings) / 2)
|
||||
|
||||
@@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
||||
|
||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def mkdirs(paths):
|
||||
def mkdir_and_rename(path):
|
||||
if os.path.exists(path):
|
||||
new_name = path + "_archived_" + get_timestamp()
|
||||
print("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
os.replace(path, new_name)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .schedulers import SCHEDULER_MAP
|
||||
23
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
23
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
@@ -10,6 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
@@ -59,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||
): # in case a token with literal angle brackets encountered
|
||||
print(f">> Loaded local embedding for trigger {concept_name}")
|
||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
print(f">> Loaded remote embedding for trigger {concept_name}")
|
||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||
|
||||
@@ -85,8 +86,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
print(
|
||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
logger.warning(
|
||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -105,8 +106,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
print(
|
||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
logger.info(
|
||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
@@ -120,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
print(f" | The error was {str(e)}")
|
||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
@@ -133,8 +134,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
:return: The token id for the added embedding, either existing or newly-added.
|
||||
"""
|
||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||
print(
|
||||
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
logger.warning(
|
||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
)
|
||||
return
|
||||
if not self.full_precision:
|
||||
@@ -155,11 +156,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
except ValueError as e:
|
||||
if str(e).startswith("Warning"):
|
||||
print(f">> {str(e)}")
|
||||
logger.warning(f"{str(e)}")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
print(
|
||||
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
logger.error(
|
||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -219,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
for ti in self.textual_inversions:
|
||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||
if ti.embedding_vector_length > 1:
|
||||
print(
|
||||
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
logger.info(
|
||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
)
|
||||
try:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
except ValueError as e:
|
||||
print(
|
||||
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
logger.debug(
|
||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
)
|
||||
print(f" | The error was {str(e)}")
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
continue
|
||||
injected_token_ids.append(ti.trigger_token_id)
|
||||
injected_token_ids.extend(ti.pad_token_ids)
|
||||
@@ -306,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files > 0:
|
||||
print(
|
||||
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||
logger.critical(
|
||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
@@ -334,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v1 embedding file: {basename}')
|
||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
||||
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
@@ -342,7 +343,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = f'<basename>'
|
||||
trigger = '<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
@@ -365,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v2 embedding file: {basename}')
|
||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
@@ -384,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
print(f" ** {basename}: Unrecognized embedding format")
|
||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -393,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v3 embedding file: {basename}')
|
||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
@@ -411,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
print(f' | Loading v4 embedding file: {short_path}')
|
||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(f" ** Invalid embeddings file: {short_path}")
|
||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
||||
else:
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
|
||||
110
invokeai/backend/util/logging.py
Normal file
110
invokeai/backend/util/logging.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
||||
|
||||
"""invokeai.util.logging
|
||||
|
||||
Logging class for InvokeAI that produces console messages
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
|
||||
(or)
|
||||
logger = InvokeAILogger.getLogger(__name__) // To use the filename
|
||||
|
||||
logger.critical('this is critical') // Critical Message
|
||||
logger.error('this is an error') // Error Message
|
||||
logger.warning('this is a warning') // Warning Message
|
||||
logger.info('this is info') // Info Message
|
||||
logger.debug('this is debugging') // Debug Message
|
||||
|
||||
Console messages:
|
||||
[12-05-2023 20]::[InvokeAI]::CRITICAL --> This is an info message [In Bold Red]
|
||||
[12-05-2023 20]::[InvokeAI]::ERROR --> This is an info message [In Red]
|
||||
[12-05-2023 20]::[InvokeAI]::WARNING --> This is an info message [In Yellow]
|
||||
[12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey]
|
||||
[12-05-2023 20]::[InvokeAI]::DEBUG --> This is an info message [In Grey]
|
||||
|
||||
Alternate Method (in this case the logger name will be set to InvokeAI):
|
||||
import invokeai.backend.util.logging as IAILogger
|
||||
IAILogger.debug('this is a debugging message')
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
# module level functions
|
||||
def debug(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
|
||||
|
||||
def info(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
|
||||
|
||||
def warning(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
|
||||
|
||||
def error(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
|
||||
|
||||
def critical(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
|
||||
|
||||
def log(level, msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
|
||||
|
||||
def disable(level=logging.CRITICAL):
|
||||
InvokeAILogger.getLogger().disable(level)
|
||||
|
||||
def basicConfig(**kwargs):
|
||||
InvokeAILogger.getLogger().basicConfig(**kwargs)
|
||||
|
||||
def getLogger(name: str = None) -> logging.Logger:
|
||||
return InvokeAILogger.getLogger(name)
|
||||
|
||||
|
||||
class InvokeAILogFormatter(logging.Formatter):
|
||||
'''
|
||||
Custom Formatting for the InvokeAI Logger
|
||||
'''
|
||||
|
||||
# Color Codes
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
cyan = "\x1b[36;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
reset = "\x1b[0m"
|
||||
|
||||
# Log Format
|
||||
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
|
||||
|
||||
# Format Map
|
||||
FORMATS = {
|
||||
logging.DEBUG: cyan + format + reset,
|
||||
logging.INFO: grey + format + reset,
|
||||
logging.WARNING: yellow + format + reset,
|
||||
logging.ERROR: red + format + reset,
|
||||
logging.CRITICAL: bold_red + format + reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
|
||||
@classmethod
|
||||
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
|
||||
if name not in self.loggers:
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
ch = logging.StreamHandler()
|
||||
fmt = InvokeAILogFormatter()
|
||||
ch.setFormatter(fmt)
|
||||
logger.addHandler(ch)
|
||||
self.loggers[name] = logger
|
||||
return self.loggers[name]
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .devices import torch_dtype
|
||||
|
||||
|
||||
@@ -38,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
logger.warning("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@@ -80,8 +81,8 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||
logger.debug(
|
||||
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||
)
|
||||
return total_params
|
||||
|
||||
@@ -132,8 +133,8 @@ def parallel_data_prefetch(
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
print(
|
||||
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
logger.warning(
|
||||
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
@@ -175,7 +176,7 @@ def parallel_data_prefetch(
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
print("Start prefetching...")
|
||||
logger.info("Start prefetching...")
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
@@ -194,7 +195,7 @@ def parallel_data_prefetch(
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e)
|
||||
logger.error("Exception: ", e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
@@ -202,7 +203,7 @@ def parallel_data_prefetch(
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
|
||||
if target_data_type == "ndarray":
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
@@ -318,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
resp = requests.get(url, headers=header, stream=True) # new request with range
|
||||
|
||||
if exist_size > content_length:
|
||||
print("* corrupt existing file found. re-downloading")
|
||||
logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or exist_size == content_length:
|
||||
print(f"* {dest}: complete file found. Skipping.")
|
||||
logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
return dest
|
||||
elif resp.status_code == 206 or exist_size > 0:
|
||||
print(f"* {dest}: partial file found. Resuming...")
|
||||
logger.warning(f"{dest}: partial file found. Resuming...")
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {dest}: {resp.reason}")
|
||||
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
||||
else:
|
||||
print(f"* {dest}: Downloading...")
|
||||
logger.error(f"{dest}: Downloading...")
|
||||
|
||||
try:
|
||||
if content_length < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
|
||||
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(dest, open_mode) as file, tqdm(
|
||||
@@ -349,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {dest}: {str(e)}")
|
||||
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
|
||||
return None
|
||||
|
||||
return dest
|
||||
|
||||
@@ -19,6 +19,7 @@ from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
import invokeai.frontend.web.dist as frontend
|
||||
|
||||
from .. import Generate
|
||||
@@ -77,7 +78,6 @@ class InvokeAIWebServer:
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
# Socket IO
|
||||
logger = True if args.web_verbose else False
|
||||
engineio_logger = True if args.web_verbose else False
|
||||
max_http_buffer_size = 10000000
|
||||
|
||||
@@ -213,7 +213,7 @@ class InvokeAIWebServer:
|
||||
self.load_socketio_listeners(self.socketio)
|
||||
|
||||
if args.gui:
|
||||
print(">> Launching Invoke AI GUI")
|
||||
logger.info("Launching Invoke AI GUI")
|
||||
try:
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
@@ -231,17 +231,17 @@ class InvokeAIWebServer:
|
||||
sys.exit(0)
|
||||
else:
|
||||
useSSL = args.certfile or args.keyfile
|
||||
print(">> Started Invoke AI Web Server")
|
||||
logger.info("Started Invoke AI Web Server")
|
||||
if self.host == "0.0.0.0":
|
||||
print(
|
||||
logger.info(
|
||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||
logger.info(
|
||||
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||
)
|
||||
print(
|
||||
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||
logger.info(
|
||||
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||
)
|
||||
if not useSSL:
|
||||
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
||||
@@ -273,7 +273,7 @@ class InvokeAIWebServer:
|
||||
# path for thumbnail images
|
||||
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
||||
# txt log
|
||||
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
|
||||
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
|
||||
# make all output paths
|
||||
[
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -290,7 +290,7 @@ class InvokeAIWebServer:
|
||||
def load_socketio_listeners(self, socketio):
|
||||
@socketio.on("requestSystemConfig")
|
||||
def handle_request_capabilities():
|
||||
print(">> System config requested")
|
||||
logger.info("System config requested")
|
||||
config = self.get_system_config()
|
||||
config["model_list"] = self.generate.model_manager.list_models()
|
||||
config["infill_methods"] = infill_methods()
|
||||
@@ -330,7 +330,7 @@ class InvokeAIWebServer:
|
||||
if model_name in current_model_list:
|
||||
update = True
|
||||
|
||||
print(f">> Adding New Model: {model_name}")
|
||||
logger.info(f"Adding New Model: {model_name}")
|
||||
|
||||
self.generate.model_manager.add_model(
|
||||
model_name=model_name,
|
||||
@@ -348,14 +348,14 @@ class InvokeAIWebServer:
|
||||
"update": update,
|
||||
},
|
||||
)
|
||||
print(f">> New Model Added: {model_name}")
|
||||
logger.info(f"New Model Added: {model_name}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@socketio.on("deleteModel")
|
||||
def handle_delete_model(model_name: str):
|
||||
try:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
self.generate.model_manager.del_model(model_name)
|
||||
self.generate.model_manager.commit(opt.conf)
|
||||
updated_model_list = self.generate.model_manager.list_models()
|
||||
@@ -366,14 +366,14 @@ class InvokeAIWebServer:
|
||||
"model_list": updated_model_list,
|
||||
},
|
||||
)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@socketio.on("requestModelChange")
|
||||
def handle_set_model(model_name: str):
|
||||
try:
|
||||
print(f">> Model change requested: {model_name}")
|
||||
logger.info(f"Model change requested: {model_name}")
|
||||
model = self.generate.set_model(model_name)
|
||||
model_list = self.generate.model_manager.list_models()
|
||||
if model is None:
|
||||
@@ -454,7 +454,7 @@ class InvokeAIWebServer:
|
||||
"update": True,
|
||||
},
|
||||
)
|
||||
print(f">> Model Converted: {model_name}")
|
||||
logger.info(f"Model Converted: {model_name}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@@ -490,7 +490,7 @@ class InvokeAIWebServer:
|
||||
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
"vae", None
|
||||
):
|
||||
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
||||
merged_model_config.update(vae=vae)
|
||||
|
||||
self.generate.model_manager.import_diffuser_model(
|
||||
@@ -507,8 +507,8 @@ class InvokeAIWebServer:
|
||||
"update": True,
|
||||
},
|
||||
)
|
||||
print(f">> Models Merged: {models_to_merge}")
|
||||
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
logger.info(f"Models Merged: {models_to_merge}")
|
||||
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
||||
except Exception as e:
|
||||
self.handle_exceptions(e)
|
||||
|
||||
@@ -698,7 +698,7 @@ class InvokeAIWebServer:
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f">> Unable to load {path}")
|
||||
logger.info(f"Unable to load {path}")
|
||||
socketio.emit(
|
||||
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
||||
)
|
||||
@@ -735,9 +735,9 @@ class InvokeAIWebServer:
|
||||
printable_parameters["init_mask"][:64] + "..."
|
||||
)
|
||||
|
||||
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||
print(f">> ESRGAN Parameters: {esrgan_parameters}")
|
||||
print(f">> Facetool Parameters: {facetool_parameters}")
|
||||
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
||||
logger.info(f"Facetool Parameters: {facetool_parameters}")
|
||||
|
||||
self.generate_images(
|
||||
generation_parameters,
|
||||
@@ -750,8 +750,8 @@ class InvokeAIWebServer:
|
||||
@socketio.on("runPostprocessing")
|
||||
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
||||
try:
|
||||
print(
|
||||
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||
logger.info(
|
||||
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||
)
|
||||
|
||||
progress = Progress()
|
||||
@@ -861,14 +861,14 @@ class InvokeAIWebServer:
|
||||
|
||||
@socketio.on("cancel")
|
||||
def handle_cancel():
|
||||
print(">> Cancel processing requested")
|
||||
logger.info("Cancel processing requested")
|
||||
self.canceled.set()
|
||||
|
||||
# TODO: I think this needs a safety mechanism.
|
||||
@socketio.on("deleteImage")
|
||||
def handle_delete_image(url, thumbnail, uuid, category):
|
||||
try:
|
||||
print(f'>> Delete requested "{url}"')
|
||||
logger.info(f'Delete requested "{url}"')
|
||||
from send2trash import send2trash
|
||||
|
||||
path = self.get_image_path_from_url(url)
|
||||
@@ -1263,7 +1263,7 @@ class InvokeAIWebServer:
|
||||
image, os.path.basename(path), self.thumbnail_image_path
|
||||
)
|
||||
|
||||
print(f'\n\n>> Image generated: "{path}"\n')
|
||||
logger.info(f'Image generated: "{path}"\n')
|
||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||
|
||||
if progress.total_iterations > progress.current_iteration:
|
||||
@@ -1329,7 +1329,7 @@ class InvokeAIWebServer:
|
||||
except Exception as e:
|
||||
# Clear the CUDA cache on an exception
|
||||
self.empty_cuda_cache()
|
||||
print(e)
|
||||
logger.error(e)
|
||||
self.handle_exceptions(e)
|
||||
|
||||
def empty_cuda_cache(self):
|
||||
|
||||
@@ -4,17 +4,21 @@ from .parse_seed_weights import parse_seed_weights
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
# diffusers:
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
'heun_k',
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ if sys.platform == "darwin":
|
||||
import pyparsing # type: ignore
|
||||
|
||||
import invokeai.version as invokeai
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...backend import Generate, ModelManager
|
||||
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
||||
@@ -69,7 +70,7 @@ def main():
|
||||
# run any post-install patches needed
|
||||
run_patches()
|
||||
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
if not args.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
@@ -78,8 +79,8 @@ def main():
|
||||
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
)
|
||||
|
||||
print(f">> {invokeai.__app_name__}, version {invokeai.__version__}")
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# loading here to avoid long delays on startup
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
@@ -121,7 +122,7 @@ def main():
|
||||
else:
|
||||
raise FileNotFoundError(f"{opt.infile} not found.")
|
||||
except (FileNotFoundError, IOError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
logger.critical('Aborted',exc_info=True)
|
||||
sys.exit(-1)
|
||||
|
||||
# creating a Generate object:
|
||||
@@ -142,12 +143,12 @@ def main():
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(opt, e)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
except (IOError, KeyError):
|
||||
logger.critical("Aborted",exc_info=True)
|
||||
sys.exit(-1)
|
||||
|
||||
if opt.seamless:
|
||||
print(">> changed to seamless tiling mode")
|
||||
logger.info("Changed to seamless tiling mode")
|
||||
|
||||
# preload the model
|
||||
try:
|
||||
@@ -180,9 +181,7 @@ def main():
|
||||
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
||||
)
|
||||
except Exception:
|
||||
print(">> An error occurred:")
|
||||
traceback.print_exc()
|
||||
|
||||
logger.error("An error occurred",exc_info=True)
|
||||
|
||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||
def main_loop(gen, opt):
|
||||
@@ -248,7 +247,7 @@ def main_loop(gen, opt):
|
||||
if not opt.prompt:
|
||||
oldargs = metadata_from_png(opt.init_img)
|
||||
opt.prompt = oldargs.prompt
|
||||
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||
except (OSError, AttributeError, KeyError):
|
||||
pass
|
||||
|
||||
@@ -265,9 +264,9 @@ def main_loop(gen, opt):
|
||||
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
||||
try:
|
||||
opt.init_img = last_results[int(opt.init_img)][0]
|
||||
print(f">> Reusing previous image {opt.init_img}")
|
||||
logger.info(f"Reusing previous image {opt.init_img}")
|
||||
except IndexError:
|
||||
print(f">> No previous initial image at position {opt.init_img} found")
|
||||
logger.info(f"No previous initial image at position {opt.init_img} found")
|
||||
opt.init_img = None
|
||||
continue
|
||||
|
||||
@@ -288,9 +287,9 @@ def main_loop(gen, opt):
|
||||
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
||||
try:
|
||||
opt.seed = last_results[opt.seed][1]
|
||||
print(f">> Reusing previous seed {opt.seed}")
|
||||
logger.info(f"Reusing previous seed {opt.seed}")
|
||||
except IndexError:
|
||||
print(f">> No previous seed at position {opt.seed} found")
|
||||
logger.info(f"No previous seed at position {opt.seed} found")
|
||||
opt.seed = None
|
||||
continue
|
||||
|
||||
@@ -309,7 +308,7 @@ def main_loop(gen, opt):
|
||||
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
||||
current_outdir = os.path.join(opt.outdir, subdir)
|
||||
|
||||
print('Writing files to directory: "' + current_outdir + '"')
|
||||
logger.info('Writing files to directory: "' + current_outdir + '"')
|
||||
|
||||
# make sure the output directory exists
|
||||
if not os.path.exists(current_outdir):
|
||||
@@ -438,15 +437,14 @@ def main_loop(gen, opt):
|
||||
catch_interrupts=catch_ctrl_c,
|
||||
**vars(opt),
|
||||
)
|
||||
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
|
||||
print("** An error occurred while processing your prompt **")
|
||||
print(f"** {str(e)} **")
|
||||
except (PromptParser.ParsingException, pyparsing.ParseException):
|
||||
logger.error("An error occurred while processing your prompt",exc_info=True)
|
||||
elif operation == "postprocess":
|
||||
print(f">> fixing {opt.prompt}")
|
||||
logger.info(f"fixing {opt.prompt}")
|
||||
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
||||
|
||||
elif operation == "mask":
|
||||
print(f">> generating masks from {opt.prompt}")
|
||||
logger.info(f"generating masks from {opt.prompt}")
|
||||
do_textmask(gen, opt, image_writer)
|
||||
|
||||
if opt.grid and len(grid_images) > 0:
|
||||
@@ -469,12 +467,12 @@ def main_loop(gen, opt):
|
||||
)
|
||||
results = [[path, formatted_dream_prompt]]
|
||||
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
except AssertionError:
|
||||
logger.error(e)
|
||||
continue
|
||||
|
||||
except OSError as e:
|
||||
print(e)
|
||||
logger.error(e)
|
||||
continue
|
||||
|
||||
print("Outputs:")
|
||||
@@ -513,7 +511,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
gen.set_model(model_name)
|
||||
add_embedding_terms(gen, completer)
|
||||
except KeyError as e:
|
||||
print(str(e))
|
||||
logger.error(e)
|
||||
except Exception as e:
|
||||
report_model_error(opt, e)
|
||||
completer.add_history(command)
|
||||
@@ -527,8 +525,8 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith("!import"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print(
|
||||
"** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||
logger.warning(
|
||||
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
@@ -541,7 +539,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith(("!convert", "!optimize")):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the path to a .ckpt or .safetensors model")
|
||||
logger.warning("please provide the path to a .ckpt or .safetensors model")
|
||||
else:
|
||||
try:
|
||||
convert_model(path[1], gen, opt, completer)
|
||||
@@ -553,7 +551,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith("!edit"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the name of a model")
|
||||
logger.warning("please provide the name of a model")
|
||||
else:
|
||||
edit_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
@@ -562,7 +560,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
elif command.startswith("!del"):
|
||||
path = shlex.split(command)
|
||||
if len(path) < 2:
|
||||
print("** please provide the name of a model")
|
||||
logger.warning("please provide the name of a model")
|
||||
else:
|
||||
del_config(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
@@ -642,8 +640,8 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
try:
|
||||
default_name = url_attachment_name(model_path)
|
||||
default_name = Path(default_name).stem
|
||||
except Exception as e:
|
||||
print(f"** URL: {str(e)}")
|
||||
except Exception:
|
||||
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
|
||||
model_name, model_desc = _get_model_name_and_desc(
|
||||
gen.model_manager,
|
||||
completer,
|
||||
@@ -664,11 +662,11 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
model_config_file=config_file,
|
||||
)
|
||||
if not imported_name:
|
||||
print("** Aborting import.")
|
||||
logger.error("Aborting import.")
|
||||
return
|
||||
|
||||
if not _verify_load(imported_name, gen):
|
||||
print("** model failed to load. Discarding configuration entry")
|
||||
logger.error("model failed to load. Discarding configuration entry")
|
||||
gen.model_manager.del_model(imported_name)
|
||||
return
|
||||
if click.confirm("Make this the default model?", default=False):
|
||||
@@ -676,7 +674,7 @@ def import_model(model_path: str, gen, opt, completer):
|
||||
|
||||
gen.model_manager.commit(opt.conf)
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
print(f">> {imported_name} successfully installed")
|
||||
logger.info(f"{imported_name} successfully installed")
|
||||
|
||||
def _pick_configuration_file(completer)->Path:
|
||||
print(
|
||||
@@ -720,21 +718,21 @@ Please select the type of this model:
|
||||
return choice
|
||||
|
||||
def _verify_load(model_name: str, gen) -> bool:
|
||||
print(">> Verifying that new model loads...")
|
||||
logger.info("Verifying that new model loads...")
|
||||
current_model = gen.model_name
|
||||
try:
|
||||
if not gen.set_model(model_name):
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"** model failed to load: {str(e)}")
|
||||
print(
|
||||
logger.warning(f"model failed to load: {str(e)}")
|
||||
logger.warning(
|
||||
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
|
||||
)
|
||||
return False
|
||||
if click.confirm("Keep model loaded?", default=True):
|
||||
gen.set_model(model_name)
|
||||
else:
|
||||
print(">> Restoring previous model")
|
||||
logger.info("Restoring previous model")
|
||||
gen.set_model(current_model)
|
||||
return True
|
||||
|
||||
@@ -757,7 +755,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
ckpt_path = None
|
||||
original_config_file = None
|
||||
if model_name_or_path == gen.model_name:
|
||||
print("** Can't convert the active model. !switch to another model first. **")
|
||||
logger.warning("Can't convert the active model. !switch to another model first. **")
|
||||
return
|
||||
elif model_info := manager.model_info(model_name_or_path):
|
||||
if "weights" in model_info:
|
||||
@@ -767,7 +765,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
model_description = model_info["description"]
|
||||
vae_path = model_info.get("vae")
|
||||
else:
|
||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||
logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
|
||||
return
|
||||
model_name = manager.convert_and_import(
|
||||
ckpt_path,
|
||||
@@ -788,16 +786,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
manager.commit(opt.conf)
|
||||
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
||||
ckpt_path.unlink(missing_ok=True)
|
||||
print(f"{ckpt_path} deleted")
|
||||
logger.warning(f"{ckpt_path} deleted")
|
||||
|
||||
|
||||
def del_config(model_name: str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
if model_name == current_model:
|
||||
print("** Can't delete active model. !switch to another model first. **")
|
||||
logger.warning("Can't delete active model. !switch to another model first. **")
|
||||
return
|
||||
if model_name not in gen.model_manager.config:
|
||||
print(f"** Unknown model {model_name}")
|
||||
logger.warning(f"Unknown model {model_name}")
|
||||
return
|
||||
|
||||
if not click.confirm(
|
||||
@@ -810,17 +808,17 @@ def del_config(model_name: str, gen, opt, completer):
|
||||
)
|
||||
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
||||
gen.model_manager.commit(opt.conf)
|
||||
print(f"** {model_name} deleted")
|
||||
logger.warning(f"{model_name} deleted")
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
|
||||
|
||||
def edit_model(model_name: str, gen, opt, completer):
|
||||
manager = gen.model_manager
|
||||
if not (info := manager.model_info(model_name)):
|
||||
print(f"** Unknown model {model_name}")
|
||||
logger.warning(f"** Unknown model {model_name}")
|
||||
return
|
||||
|
||||
print(f"\n>> Editing model {model_name} from configuration file {opt.conf}")
|
||||
print()
|
||||
logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
|
||||
new_name = _get_model_name(manager.list_models(), completer, model_name)
|
||||
|
||||
for attribute in info.keys():
|
||||
@@ -858,7 +856,7 @@ def edit_model(model_name: str, gen, opt, completer):
|
||||
manager.set_default_model(new_name)
|
||||
manager.commit(opt.conf)
|
||||
completer.update_models(manager.list_models())
|
||||
print(">> Model successfully updated")
|
||||
logger.info("Model successfully updated")
|
||||
|
||||
|
||||
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
||||
@@ -869,11 +867,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
||||
if len(model_name) == 0:
|
||||
model_name = default_name
|
||||
if not re.match("^[\w._+:/-]+$", model_name):
|
||||
print(
|
||||
'** model name must contain only words, digits and the characters "._+:/-" **'
|
||||
logger.warning(
|
||||
'model name must contain only words, digits and the characters "._+:/-" **'
|
||||
)
|
||||
elif model_name != default_name and model_name in existing_names:
|
||||
print(f"** the name {model_name} is already in use. Pick another.")
|
||||
logger.warning(f"the name {model_name} is already in use. Pick another.")
|
||||
else:
|
||||
done = True
|
||||
return model_name
|
||||
@@ -940,11 +938,10 @@ def do_postprocess(gen, opt, callback):
|
||||
opt=opt,
|
||||
)
|
||||
except OSError:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"** {file_path}: file could not be read")
|
||||
logger.error(f"{file_path}: file could not be read",exc_info=True)
|
||||
return
|
||||
except (KeyError, AttributeError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
|
||||
return
|
||||
return opt.last_operation
|
||||
|
||||
@@ -999,13 +996,13 @@ def prepare_image_metadata(
|
||||
try:
|
||||
filename = opt.fnformat.format(**wildcards)
|
||||
except KeyError as e:
|
||||
print(
|
||||
f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
|
||||
logger.error(
|
||||
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
|
||||
)
|
||||
filename = f"{prefix}.{seed}.png"
|
||||
except IndexError:
|
||||
print(
|
||||
"** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
|
||||
logger.error(
|
||||
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
|
||||
)
|
||||
filename = f"{prefix}.{seed}.png"
|
||||
|
||||
@@ -1094,14 +1091,14 @@ def split_variations(variations_string) -> list:
|
||||
for part in variations_string.split(","):
|
||||
seed_and_weight = part.split(":")
|
||||
if len(seed_and_weight) != 2:
|
||||
print(f'** Could not parse with_variation part "{part}"')
|
||||
logger.warning(f'Could not parse with_variation part "{part}"')
|
||||
broken = True
|
||||
break
|
||||
try:
|
||||
seed = int(seed_and_weight[0])
|
||||
weight = float(seed_and_weight[1])
|
||||
except ValueError:
|
||||
print(f'** Could not parse with_variation part "{part}"')
|
||||
logger.warning(f'Could not parse with_variation part "{part}"')
|
||||
broken = True
|
||||
break
|
||||
parts.append([seed, weight])
|
||||
@@ -1125,23 +1122,23 @@ def load_face_restoration(opt):
|
||||
opt.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
logger.info("Face restoration disabled")
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
return gfpgan, codeformer, esrgan
|
||||
|
||||
|
||||
def make_step_callback(gen, opt, prefix):
|
||||
destination = os.path.join(opt.outdir, "intermediates", prefix)
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
print(f">> Intermediate images will be written into {destination}")
|
||||
logger.info(f"Intermediate images will be written into {destination}")
|
||||
|
||||
def callback(state: PipelineIntermediateState):
|
||||
latents = state.latents
|
||||
@@ -1183,21 +1180,20 @@ def retrieve_dream_command(opt, command, completer):
|
||||
try:
|
||||
cmd = dream_cmd_from_png(path)
|
||||
except OSError:
|
||||
print(f"## {tokens[0]}: file could not be read")
|
||||
logger.error(f"{tokens[0]}: file could not be read")
|
||||
except (KeyError, AttributeError, IndexError):
|
||||
print(f"## {tokens[0]}: file has no metadata")
|
||||
logger.error(f"{tokens[0]}: file has no metadata")
|
||||
except:
|
||||
print(f"## {tokens[0]}: file could not be processed")
|
||||
logger.error(f"{tokens[0]}: file could not be processed")
|
||||
if len(cmd) > 0:
|
||||
completer.set_line(cmd)
|
||||
|
||||
|
||||
def write_commands(opt, file_path: str, outfilepath: str):
|
||||
dir, basename = os.path.split(file_path)
|
||||
try:
|
||||
paths = sorted(list(Path(dir).glob(basename)))
|
||||
except ValueError:
|
||||
print(f'## "{basename}": unacceptable pattern')
|
||||
logger.error(f'"{basename}": unacceptable pattern')
|
||||
return
|
||||
|
||||
commands = []
|
||||
@@ -1206,9 +1202,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
||||
try:
|
||||
cmd = dream_cmd_from_png(path)
|
||||
except (KeyError, AttributeError, IndexError):
|
||||
print(f"## {path}: file has no metadata")
|
||||
logger.error(f"{path}: file has no metadata")
|
||||
except:
|
||||
print(f"## {path}: file could not be processed")
|
||||
logger.error(f"{path}: file could not be processed")
|
||||
if cmd:
|
||||
commands.append(f"# {path}")
|
||||
commands.append(cmd)
|
||||
@@ -1218,18 +1214,18 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
||||
outfilepath = os.path.join(opt.outdir, basename)
|
||||
with open(outfilepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(commands))
|
||||
print(f">> File {outfilepath} with commands created")
|
||||
logger.info(f"File {outfilepath} with commands created")
|
||||
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print(
|
||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.warning(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
print(
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
if not click.confirm(
|
||||
@@ -1238,7 +1234,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
):
|
||||
return
|
||||
|
||||
print("invokeai-configure is launching....\n")
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
@@ -1255,7 +1251,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
from ..install import invokeai_configure
|
||||
|
||||
invokeai_configure()
|
||||
print("** InvokeAI will now restart")
|
||||
logger.warning("InvokeAI will now restart")
|
||||
sys.argv = previous_args
|
||||
main() # would rather do a os.exec(), but doesn't exist?
|
||||
sys.exit(0)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
'''
|
||||
Minimalist updater script. Prompts user for the tag or branch to update to and runs
|
||||
pip install <path_to_git_source>.
|
||||
"""
|
||||
'''
|
||||
import os
|
||||
import platform
|
||||
|
||||
import requests
|
||||
from rich import box, print
|
||||
from rich.console import Console, Group, group
|
||||
@@ -16,8 +15,10 @@ from rich.text import Text
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
|
||||
INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
|
||||
INVOKE_AI_TAG="https://github.com/invoke-ai/InvokeAI/archive/refs/tags"
|
||||
INVOKE_AI_BRANCH="https://github.com/invoke-ai/InvokeAI/archive/refs/heads"
|
||||
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
|
||||
OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
@@ -28,22 +29,22 @@ if OS == "Windows":
|
||||
else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
|
||||
def get_versions() -> dict:
|
||||
def get_versions()->dict:
|
||||
return requests.get(url=INVOKE_AI_REL).json()
|
||||
|
||||
|
||||
def welcome(versions: dict):
|
||||
|
||||
@group()
|
||||
def text():
|
||||
yield f"InvokeAI Version: [bold yellow]{__version__}"
|
||||
yield ""
|
||||
yield "This script will update InvokeAI to the latest release, or to a development version of your choice."
|
||||
yield ""
|
||||
yield "[bold yellow]Options:"
|
||||
yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
||||
yield f'InvokeAI Version: [bold yellow]{__version__}'
|
||||
yield ''
|
||||
yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
|
||||
yield ''
|
||||
yield '[bold yellow]Options:'
|
||||
yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
||||
[2] Update to the bleeding-edge development version ([italic]main[/italic])
|
||||
[3] Manually enter the tag or branch name you wish to update"""
|
||||
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
|
||||
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to'''
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@@ -59,33 +60,41 @@ def welcome(versions: dict):
|
||||
)
|
||||
console.line()
|
||||
|
||||
|
||||
def main():
|
||||
versions = get_versions()
|
||||
welcome(versions)
|
||||
|
||||
tag = None
|
||||
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
|
||||
branch = None
|
||||
release = None
|
||||
choice = Prompt.ask('Choice:',choices=['1','2','3','4'],default='1')
|
||||
|
||||
if choice=='1':
|
||||
release = versions[0]['tag_name']
|
||||
elif choice=='2':
|
||||
release = 'main'
|
||||
elif choice=='3':
|
||||
tag = Prompt.ask('Enter an InvokeAI tag name')
|
||||
elif choice=='4':
|
||||
branch = Prompt.ask('Enter an InvokeAI branch name')
|
||||
|
||||
if choice == "1":
|
||||
tag = versions[0]["tag_name"]
|
||||
elif choice == "2":
|
||||
tag = "main"
|
||||
elif choice == "3":
|
||||
tag = Prompt.ask("Enter an InvokeAI tag or branch name")
|
||||
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{tag}[/yellow]")
|
||||
cmd = f"pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517"
|
||||
print("")
|
||||
print("")
|
||||
if os.system(cmd) == 0:
|
||||
print(f":heavy_check_mark: Upgrade successful")
|
||||
print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
|
||||
if release:
|
||||
cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
|
||||
elif tag:
|
||||
cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
|
||||
else:
|
||||
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
|
||||
|
||||
|
||||
cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
|
||||
print('')
|
||||
print('')
|
||||
if os.system(cmd)==0:
|
||||
print(f':heavy_check_mark: Upgrade successful')
|
||||
else:
|
||||
print(f':exclamation: [bold red]Upgrade failed[/red bold]')
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals, global_config_dir
|
||||
|
||||
from ...backend.config.model_install_backend import (
|
||||
@@ -455,8 +456,8 @@ def main():
|
||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||
|
||||
if not global_config_dir().exists():
|
||||
print(
|
||||
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||
logger.info(
|
||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||
)
|
||||
from invokeai.frontend.install import invokeai_configure
|
||||
|
||||
@@ -466,18 +467,18 @@ def main():
|
||||
try:
|
||||
select_and_download_models(opt)
|
||||
except AssertionError as e:
|
||||
print(str(e))
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
logger.info("Goodbye! Come back soon.")
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
print(
|
||||
"** Insufficient vertical space for the interface. Please make your window taller and try again"
|
||||
logger.error(
|
||||
"Insufficient vertical space for the interface. Please make your window taller and try again"
|
||||
)
|
||||
elif str(e).startswith("addwstr"):
|
||||
print(
|
||||
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
logger.error(
|
||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from ...backend.globals import (
|
||||
global_models_dir,
|
||||
global_set_root,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend.model_management import ModelManager
|
||||
from ...frontend.install.widgets import FloatTitleSlider
|
||||
|
||||
@@ -113,7 +115,7 @@ def merge_diffusion_models_and_commit(
|
||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
if vae := model_manager.config[models[0]].get("vae", None):
|
||||
print(f">> Using configured VAE assigned to {models[0]}")
|
||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
||||
import_args.update(vae=vae)
|
||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||
model_manager.commit(config_file)
|
||||
@@ -391,10 +393,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
for name in self.model_manager.model_names()
|
||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
||||
]
|
||||
print(model_names)
|
||||
return sorted(model_names)
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -414,7 +414,7 @@ def run_gui(args: Namespace):
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merge_diffusion_models_and_commit(**args)
|
||||
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
@@ -425,8 +425,8 @@ def run_cli(args: Namespace):
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
print(
|
||||
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
logger.info(
|
||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||
@@ -435,7 +435,7 @@ def run_cli(args: Namespace):
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merge_diffusion_models_and_commit(**vars(args))
|
||||
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def main():
|
||||
@@ -455,17 +455,16 @@ def main():
|
||||
run_cli(args)
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
print(
|
||||
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||
logger.error(
|
||||
"You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"** Not enough room for the user interface. Try making this window larger."
|
||||
logger.error(
|
||||
"Not enough room for the user interface. Try making this window larger."
|
||||
)
|
||||
sys.exit(-1)
|
||||
except Exception:
|
||||
print(">> An error occurred:")
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
@@ -20,6 +20,7 @@ import npyscreen
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals, global_set_root
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
@@ -368,14 +369,14 @@ def copy_to_embeddings_folder(args: dict):
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
shutil.copy(source, destination)
|
||||
if (
|
||||
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
|
||||
).startswith(("y", "Y")):
|
||||
shutil.rmtree(Path(args["output_dir"]))
|
||||
else:
|
||||
print(f'>> Keeping {args["output_dir"]}')
|
||||
logger.info(f'Keeping {args["output_dir"]}')
|
||||
|
||||
|
||||
def save_args(args: dict):
|
||||
@@ -422,10 +423,10 @@ def do_front_end(args: Namespace):
|
||||
do_textual_inversion_training(**args)
|
||||
copy_to_embeddings_folder(args)
|
||||
except Exception as e:
|
||||
print("** An exception occurred during training. The exception was:")
|
||||
print(str(e))
|
||||
print("** DETAILS:")
|
||||
print(traceback.format_exc())
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
logger.error(str(e))
|
||||
logger.error("DETAILS:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def main():
|
||||
@@ -437,21 +438,21 @@ def main():
|
||||
else:
|
||||
do_textual_inversion_training(**vars(args))
|
||||
except AssertionError as e:
|
||||
print(str(e))
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
print(
|
||||
"** You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||
logger.error(
|
||||
"You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||
)
|
||||
elif str(e).startswith("addwstr"):
|
||||
print(
|
||||
"** Not enough window space for the interface. Please make your window larger and try again."
|
||||
logger.error(
|
||||
"Not enough window space for the interface. Please make your window larger and try again."
|
||||
)
|
||||
else:
|
||||
print(f"** An error has occurred: {str(e)}")
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
{
|
||||
"plugins": [
|
||||
[
|
||||
"transform-imports",
|
||||
{
|
||||
"lodash": {
|
||||
"transform": "lodash/${member}",
|
||||
"preventFullImport": true
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
6
invokeai/frontend/web/.gitignore
vendored
6
invokeai/frontend/web/.gitignore
vendored
@@ -34,4 +34,8 @@ stats.html
|
||||
!.yarn/plugins
|
||||
!.yarn/releases
|
||||
!.yarn/sdks
|
||||
!.yarn/versions
|
||||
!.yarn/versions
|
||||
|
||||
# Yalc
|
||||
.yalc
|
||||
yalc.lock
|
||||
40
invokeai/frontend/web/config/vite.app.config.ts
Normal file
40
invokeai/frontend/web/config/vite.app.config.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import { PluginOption, UserConfig } from 'vite';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
|
||||
export const appConfig: UserConfig = {
|
||||
base: './',
|
||||
plugins: [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
],
|
||||
build: {
|
||||
chunkSizeWarningLimit: 1500,
|
||||
},
|
||||
server: {
|
||||
// Proxy HTTP requests to the flask server
|
||||
proxy: {
|
||||
// Proxy socket.io to the nodes socketio server
|
||||
'/ws/socket.io': {
|
||||
target: 'ws://127.0.0.1:9090',
|
||||
ws: true,
|
||||
},
|
||||
// Proxy openapi schema definiton
|
||||
'/openapi.json': {
|
||||
target: 'http://127.0.0.1:9090/openapi.json',
|
||||
rewrite: (path) => path.replace(/^\/openapi.json/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
// proxy nodes api
|
||||
'/api/v1': {
|
||||
target: 'http://127.0.0.1:9090/api/v1',
|
||||
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
|
||||
changeOrigin: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
50
invokeai/frontend/web/config/vite.package.config.ts
Normal file
50
invokeai/frontend/web/config/vite.package.config.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import react from '@vitejs/plugin-react-swc';
|
||||
import path from 'path';
|
||||
import { visualizer } from 'rollup-plugin-visualizer';
|
||||
import { PluginOption, UserConfig } from 'vite';
|
||||
import dts from 'vite-plugin-dts';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||
|
||||
export const packageConfig: UserConfig = {
|
||||
base: './',
|
||||
plugins: [
|
||||
react(),
|
||||
eslint(),
|
||||
tsconfigPaths(),
|
||||
visualizer() as unknown as PluginOption,
|
||||
dts({
|
||||
insertTypesEntry: true,
|
||||
}),
|
||||
cssInjectedByJsPlugin(),
|
||||
],
|
||||
build: {
|
||||
cssCodeSplit: true,
|
||||
lib: {
|
||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
||||
name: 'InvokeAIUI',
|
||||
fileName: (format) => `invoke-ai-ui.${format}.js`,
|
||||
},
|
||||
rollupOptions: {
|
||||
external: ['react', 'react-dom', '@emotion/react'],
|
||||
output: {
|
||||
globals: {
|
||||
react: 'React',
|
||||
'react-dom': 'ReactDOM',
|
||||
'@emotion/react': 'EmotionReact',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
app: path.resolve(__dirname, '../src/app'),
|
||||
assets: path.resolve(__dirname, '../src/assets'),
|
||||
common: path.resolve(__dirname, '../src/common'),
|
||||
features: path.resolve(__dirname, '../src/features'),
|
||||
services: path.resolve(__dirname, '../src/services'),
|
||||
theme: path.resolve(__dirname, '../src/theme'),
|
||||
},
|
||||
},
|
||||
};
|
||||
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
|
||||
import{j as y,cO as Ie,r as _,cP as bt,q as Lr,cQ as o,cR as b,cS as v,cT as S,cU as Vr,cV as ut,cW as vt,cN as ft,cX as mt,n as gt,cY as ht,E as pt}from"./index-e53e8108.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-5cde7d31.js";var Or=`
|
||||
:root {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-e53e8108.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-5483945c.css">
|
||||
</head>
|
||||
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ar.json
vendored
1
invokeai/frontend/web/dist/locales/ar.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "داكن",
|
||||
"lightTheme": "فاتح",
|
||||
"greenTheme": "أخضر",
|
||||
"text2img": "نص إلى صورة",
|
||||
"img2img": "صورة إلى صورة",
|
||||
"unifiedCanvas": "لوحة موحدة",
|
||||
"nodes": "عقد",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/de.json
vendored
1
invokeai/frontend/web/dist/locales/de.json
vendored
@@ -7,7 +7,6 @@
|
||||
"darkTheme": "Dunkel",
|
||||
"lightTheme": "Hell",
|
||||
"greenTheme": "Grün",
|
||||
"text2img": "Text zu Bild",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Knoten",
|
||||
"langGerman": "Deutsch",
|
||||
|
||||
4
invokeai/frontend/web/dist/locales/en.json
vendored
4
invokeai/frontend/web/dist/locales/en.json
vendored
@@ -505,7 +505,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user