Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-17 00:48:03 -05:00

Compare commits: 567 commits between ryan/flux-... and v4.2.9.dev
37  .github/workflows/build-container.yml (vendored)
@@ -13,12 +13,6 @@ on:
    tags:
      - 'v*.*.*'
  workflow_dispatch:
    inputs:
      push-to-registry:
        description: Push the built image to the container registry
        required: false
        type: boolean
        default: false

permissions:
  contents: write
@@ -56,15 +50,16 @@ jobs:
          df -h

      - name: Checkout
        uses: actions/checkout@v4
        uses: actions/checkout@v3

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@v5
        uses: docker/metadata-action@v4
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          images: |
            ghcr.io/${{ github.repository }}
            ${{ env.DOCKERHUB_REPOSITORY }}
          tags: |
            type=ref,event=branch
            type=ref,event=tag
@@ -77,33 +72,49 @@
            suffix=-${{ matrix.gpu-driver }},onlatest=false

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3
        uses: docker/setup-qemu-action@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        uses: docker/setup-buildx-action@v2
        with:
          platforms: ${{ env.PLATFORMS }}

      - name: Login to GitHub Container Registry
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v3
        uses: docker/login-action@v2
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}

      # - name: Login to Docker Hub
      #   if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
      #   uses: docker/login-action@v2
      #   with:
      #     username: ${{ secrets.DOCKERHUB_USERNAME }}
      #     password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build container
        timeout-minutes: 40
        id: docker_build
        uses: docker/build-push-action@v6
        uses: docker/build-push-action@v4
        with:
          context: .
          file: docker/Dockerfile
          platforms: ${{ env.PLATFORMS }}
          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
          push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: |
            type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
            type=gha,scope=main-${{ matrix.gpu-driver }}
          cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}

      # - name: Docker Hub Description
      #   if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
      #   uses: peter-evans/dockerhub-description@v3
      #   with:
      #     username: ${{ secrets.DOCKERHUB_USERNAME }}
      #     password: ${{ secrets.DOCKERHUB_TOKEN }}
      #     repository: ${{ vars.DOCKERHUB_REPOSITORY }}
      #     short-description: ${{ github.event.repository.description }}
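One of the differences above is a `workflow_dispatch` trigger with a `push-to-registry` input that is consulted in the `push:` condition of the build step. For reference, a manual run supplying that input can be triggered through the GitHub REST API. The sketch below assumes a personal access token in the `GITHUB_TOKEN` environment variable and targets the upstream `invoke-ai/InvokeAI` repository; adjust both for a fork or mirror.

```python
import os

import requests

# Trigger the container build workflow manually and ask it to push the image.
resp = requests.post(
    "https://api.github.com/repos/invoke-ai/InvokeAI/actions/workflows/build-container.yml/dispatches",
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    json={"ref": "main", "inputs": {"push-to-registry": "true"}},
)
resp.raise_for_status()  # GitHub returns 204 No Content on success
```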
@@ -196,22 +196,6 @@ tips to reduce the problem:
=== "12GB VRAM GPU"

    This should be sufficient to generate larger images up to about 1280x1280.

## Checkpoint Models Load Slowly or Use Too Much RAM

The difference between diffusers models (a folder containing multiple
subfolders) and checkpoint models (a file ending with .safetensors or
.ckpt) is that InvokeAI is able to load diffusers models into memory
incrementally, while checkpoint models must be loaded all at
once. With very large models, or systems with limited RAM, you may
experience slowdowns and other memory-related issues when loading
checkpoint models.

To solve this, go to the Model Manager tab (the cube), select the
checkpoint model that's giving you trouble, and press the "Convert"
button in the upper right of your browser window. This will convert the
checkpoint into a diffusers model, after which loading should be
faster and less memory-intensive.
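If you prefer to do the conversion outside the web UI, the diffusers library can perform a roughly equivalent conversion on its own. This is only a minimal sketch, assuming a Stable Diffusion 1.5 checkpoint; the file name and output folder are placeholders, and this path bypasses InvokeAI's Model Manager entirely, so the result still has to be imported afterwards.

```python
from diffusers import StableDiffusionPipeline

# Load the single-file checkpoint and re-save it in the multi-folder diffusers layout.
pipe = StableDiffusionPipeline.from_single_file("my-model.safetensors")
pipe.save_pretrained("my-model-diffusers")
```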
## Memory Leak (Linux)
@@ -3,10 +3,8 @@

import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from tempfile import TemporaryDirectory
from typing import List, Optional, Type

@@ -19,7 +17,6 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -34,7 +31,6 @@ from invokeai.backend.model_manager.config import (
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
@@ -54,13 +50,6 @@ class ModelsList(BaseModel):
    model_config = ConfigDict(use_enum_values=True)


class CacheType(str, Enum):
    """Cache type - one of vram or ram."""

    RAM = "RAM"
    VRAM = "VRAM"


def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
    """Add a cover image URL to a model configuration."""
    cover_image = dependencies.invoker.services.model_images.get_url(config.key)
@@ -808,83 +797,3 @@ async def get_starter_models() -> list[StarterModel]:
        model.dependencies = missing_deps

    return starter_models


@model_manager_router.get(
    "/model_cache",
    operation_id="get_cache_size",
    response_model=float,
    summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
    """Return the current RAM or VRAM cache size setting (in GB)."""
    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
    value = 0.0
    if cache_type == CacheType.RAM:
        value = cache.max_cache_size
    elif cache_type == CacheType.VRAM:
        value = cache.max_vram_cache_size
    return value


@model_manager_router.put(
    "/model_cache",
    operation_id="set_cache_size",
    response_model=float,
    summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
    value: float = Query(description="The new value for the maximum cache size"),
    cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
    persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
    """Set the current RAM or VRAM cache size setting (in GB). ."""
    cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
    app_config = get_config()
    # Record initial state.
    vram_old = app_config.vram
    ram_old = app_config.ram

    # Prepare target state.
    vram_new = vram_old
    ram_new = ram_old
    if cache_type == CacheType.RAM:
        ram_new = value
    elif cache_type == CacheType.VRAM:
        vram_new = value
    else:
        raise ValueError(f"Unexpected {cache_type=}.")

    config_path = app_config.config_file_path
    new_config_path = config_path.with_suffix(".yaml.new")

    try:
        # Try to apply the target state.
        cache.max_vram_cache_size = vram_new
        cache.max_cache_size = ram_new
        app_config.ram = ram_new
        app_config.vram = vram_new
        if persist:
            app_config.write_file(new_config_path)
            shutil.move(new_config_path, config_path)
    except Exception as e:
        # If there was a failure, restore the initial state.
        cache.max_cache_size = ram_old
        cache.max_vram_cache_size = vram_old
        app_config.ram = ram_old
        app_config.vram = vram_old

        raise RuntimeError("Failed to update cache size") from e
    return value


@model_manager_router.get(
    "/stats",
    operation_id="get_stats",
    response_model=Optional[CacheStats],
    summary="Get model manager RAM cache performance statistics.",
)
async def get_stats() -> Optional[CacheStats]:
    """Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""

    return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
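For context, the cache-size routes removed above are plain FastAPI endpoints driven by query parameters. A minimal client sketch follows, assuming the model manager router is mounted under `/api/v2/models` and a local server on port 9090; neither the prefix nor the port is shown in this diff, so treat both as assumptions.

```python
import requests

BASE = "http://127.0.0.1:9090/api/v2/models"  # assumed router prefix and port

# Read the current RAM cache limit (GB).
ram_gb = requests.get(f"{BASE}/model_cache", params={"cache_type": "RAM"}).json()

# Set the VRAM cache limit to 8 GB and persist it to invokeai.yaml.
vram_gb = requests.put(
    f"{BASE}/model_cache",
    params={"value": 8.0, "cache_type": "VRAM", "persist": True},
).json()

print(f"RAM cache: {ram_gb} GB, VRAM cache: {vram_gb} GB")
```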
@@ -11,7 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
    Batch,
    BatchStatus,
    CancelByBatchIDsResult,
    CancelByDestinationResult,
    CancelByOriginResult,
    ClearResult,
    EnqueueBatchResult,
    PruneResult,
@@ -107,18 +107,16 @@ async def cancel_by_batch_ids(


@session_queue_router.put(
    "/{queue_id}/cancel_by_destination",
    operation_id="cancel_by_destination",
    "/{queue_id}/cancel_by_origin",
    operation_id="cancel_by_origin",
    responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_destination(
async def cancel_by_origin(
    queue_id: str = Path(description="The queue id to perform this operation on"),
    destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
    origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
    """Immediately cancels all queue items with the given origin"""
    return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
        queue_id=queue_id, destination=destination
    )
    return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)


@session_queue_router.put(
@@ -19,8 +19,7 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
@@ -83,10 +82,9 @@ class CompelInvocation(BaseInvocation):
            # apply all patches while the model is on the target device
            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
            tokenizer_info as tokenizer,
            LoraPatcher.apply_lora_patches(
                model=text_encoder,
                patches=_lora_loader(),
                prefix="lora_te_",
            ModelPatcher.apply_lora_text_encoder(
                text_encoder,
                loras=_lora_loader(),
                cached_weights=cached_weights,
            ),
            # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -179,9 +177,9 @@ class SDXLPromptInvocationBase:
            # apply all patches while the model is on the target device
            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
            tokenizer_info as tokenizer,
            LoraPatcher.apply_lora_patches(
            ModelPatcher.apply_lora(
                text_encoder,
                patches=_lora_loader(),
                loras=_lora_loader(),
                prefix=lora_prefix,
                cached_weights=cached_weights,
            ),
@@ -36,8 +36,7 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState
@@ -186,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
    )
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        description=FieldDescriptions.mask,
        input=Input.Connection,
        ui_order=8,
    )
@@ -980,10 +979,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
            SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
            # Apply the LoRA after unet has been moved to its target device for faster patching.
            LoraPatcher.apply_lora_patches(
                model=unet,
                patches=_lora_loader(),
                prefix="lora_unet_",
            ModelPatcher.apply_lora_unet(
                unet,
                loras=_lora_loader(),
                cached_weights=cached_weights,
            ),
        ):
@@ -45,13 +45,11 @@ class UIType(str, Enum, metaclass=MetaEnum):
    SDXLRefinerModel = "SDXLRefinerModelField"
    ONNXModel = "ONNXModelField"
    VAEModel = "VAEModelField"
    FluxVAEModel = "FluxVAEModelField"
    LoRAModel = "LoRAModelField"
    ControlNetModel = "ControlNetModelField"
    IPAdapterModel = "IPAdapterModelField"
    T2IAdapterModel = "T2IAdapterModelField"
    T5EncoderModel = "T5EncoderModelField"
    CLIPEmbedModel = "CLIPEmbedModelField"
    SpandrelImageToImageModel = "SpandrelImageToImageModelField"
    # endregion

@@ -130,7 +128,6 @@ class FieldDescriptions:
    noise = "Noise tensor"
    clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
    t5_encoder = "T5 tokenizer and text encoder"
    clip_embed_model = "CLIP Embed loader"
    unet = "UNet (scheduler, LoRAs)"
    transformer = "Transformer"
    vae = "VAE"
@@ -181,7 +178,7 @@
    )
    num_1 = "The first number"
    num_2 = "The second number"
    denoise_mask = "A mask of the region to apply the denoising process to."
    mask = "The mask to use for the operation"
    board = "The board to save the image to"
    image = "The image to process"
    tile_size = "Tile size"
@@ -1,267 +0,0 @@
from typing import Callable, Iterator, Optional, Tuple

import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
    clip_timestep_schedule,
    generate_img_ids,
    get_noise,
    get_schedule,
    pack,
    unpack,
)
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_denoise",
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run denoising process with a FLUX transformer model."""

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None,
        description=FieldDescriptions.denoise_mask,
        input=Input.Connection,
    )
    denoising_start: float = InputField(
        default=0.0,
        ge=0,
        le=1,
        description=FieldDescriptions.denoising_start,
    )
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = torch.bfloat16

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)
        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
        t5_embeddings = flux_conditioning.t5_embeds
        clip_embeddings = flux_conditioning.clip_embeds

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)

        # Prepare input noise.
        noise = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        transformer_info = context.models.load(self.transformer.transformer)
        is_schnell = "schnell" in transformer_info.config.config_path

        # Calculate the timestep schedule.
        image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=image_seq_len,
            shift=not is_schnell,
        )

        # Clip the timesteps schedule based on denoising_start and denoising_end.
        timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)

        # Prepare input latent image.
        if init_latents is not None:
            # If init_latents is provided, we are doing image-to-image.

            if is_schnell:
                context.logger.warning(
                    "Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
                    "to be poor. Consider using a FLUX dev model instead."
                )

            # Noise the orig_latents by the appropriate amount for the first timestep.
            t_0 = timesteps[0]
            x = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")

            x = noise

        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
        # denoising steps.
        if len(timesteps) <= 1:
            return x

        inpaint_mask = self._prep_inpaint_mask(context, x)

        b, _c, h, w = x.shape
        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        # Pack all latent tensors.
        init_latents = pack(init_latents) if init_latents is not None else None
        inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
        noise = pack(noise)
        x = pack(x)

        # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
        assert image_seq_len == x.shape[1]

        # Prepare inpaint extension.
        inpaint_extension: InpaintExtension | None = None
        if inpaint_mask is not None:
            assert init_latents is not None
            inpaint_extension = InpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        with (
            transformer_info.model_on_device() as (cached_weights, transformer),
            # Apply the LoRA after transformer has been moved to its target device for faster patching.
            LoraPatcher.apply_lora_patches(
                model=transformer,
                patches=self._lora_iterator(context),
                prefix="",
                cached_weights=cached_weights,
            ),
        ):
            assert isinstance(transformer, Flux)

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                inpaint_extension=inpaint_extension,
            )

        x = unpack(x.float(), self.height, self.width)
        return x

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        - Loads the mask
        - Resizes if necessary
        - Casts to same device/dtype as latents
        - Expands mask to the same shape as latents so that they line up after 'packing'

        Args:
            context (InvocationContext): The invocation context, for loading the inpaint mask.
            latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
                device, and dtype for the inpaint mask.

        Returns:
            torch.Tensor | None: Inpaint mask.
        """
        if self.denoise_mask is None:
            return None

        mask = context.tensors.load(self.denoise_mask.mask_name)

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)

        # Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
        # `latents`.
        return mask.expand_as(latents)

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, LoRAModelRaw)
            yield (lora_info.model, lora.weight)
            del lora_info

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        def step_callback(state: PipelineIntermediateState) -> None:
            state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
            context.util.flux_step_callback(state)

        return step_callback
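Two lines in the invocation above depend on each other: `image_seq_len = noise.shape[-1] * noise.shape[-2] // 4` and the later `assert image_seq_len == x.shape[1]` after `pack(x)`. They only agree if `pack` folds each 2x2 patch of latent pixels into a single sequence token. `pack` itself is not shown in this diff, so the following is just a sketch of that relationship, with an assumed 16-channel latent at one-eighth of the image resolution.

```python
import torch
from einops import rearrange


def pack_2x2(latents: torch.Tensor) -> torch.Tensor:
    # Assumed behavior of `pack`: fold each 2x2 latent patch into one sequence token.
    return rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


b, c, latent_h, latent_w = 1, 16, 128, 128  # e.g. a 1024x1024 image with an assumed 8x VAE downscale
noise = torch.zeros(b, c, latent_h, latent_w)

image_seq_len = noise.shape[-1] * noise.shape[-2] // 4  # same formula as in _run_diffusion
assert pack_2x2(noise).shape[1] == image_seq_len        # 4096 tokens
```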
@@ -1,53 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_lora_loader_output")
class FluxLoRALoaderOutput(BaseInvocationOutput):
    """FLUX LoRA Loader Output"""

    transformer: TransformerField = OutputField(
        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
    )


@invocation(
    "flux_lora_loader",
    title="FLUX LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.0.0",
)
class FluxLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="FLUX Transformer",
    )

    def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown lora: {lora_key}!")

        if any(lora.lora.key == lora_key for lora in self.transformer.loras):
            raise Exception(f'LoRA "{lora_key}" already applied to transformer.')

        transformer = self.transformer.model_copy(deep=True)
        transformer.loras.append(
            LoRAField(
                lora=self.lora,
                weight=self.weight,
            )
        )

        return FluxLoRALoaderOutput(transformer=transformer)
@@ -40,10 +40,7 @@ class FluxTextEncoderInvocation(BaseInvocation):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
        t5_embeddings = self._t5_encode(context)
        clip_embeddings = self._clip_encode(context)
        t5_embeddings, clip_embeddings = self._encode_prompt(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )
@@ -51,7 +48,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
        conditioning_name = context.conditioning.save(conditioning_data)
        return FluxConditioningOutput.build(conditioning_name)

    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
        # Load CLIP.
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        # Load T5.
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

@@ -68,15 +70,6 @@ class FluxTextEncoderInvocation(BaseInvocation):

        prompt_embeds = t5_encoder(prompt)

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds

    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

        prompt = [self.prompt]

        with (
            clip_text_encoder_info as clip_text_encoder,
            clip_tokenizer_info as clip_tokenizer,
@@ -88,5 +81,6 @@ class FluxTextEncoderInvocation(BaseInvocation):

        pooled_prompt_embeds = clip_encoder(prompt)

        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return pooled_prompt_embeds
        return prompt_embeds, pooled_prompt_embeds
172  invokeai/app/invocations/flux_text_to_image.py (new file)

@@ -0,0 +1,172 @@
import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
        transformer_info = context.models.load(self.transformer.transformer)
        inference_dtype = torch.bfloat16

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        img, img_ids = prepare_latent_img_patches(x)

        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            def step_callback() -> None:
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image before re-enabling
                # latent_image = unpack(img.float(), self.height, self.width)
                # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
                # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]

                # # Create a new tensor of the required shape [255, 255, 3]
                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format

                # # Convert to a NumPy array and then to a PIL Image
                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))

                # (width, height) = image.size
                # width *= 8
                # height *= 8

                # dataURL = image_to_dataURL(image, image_format="JPEG")

                # # TODO: move this whole function to invocation context to properly reference these variables
                # context._services.events.emit_invocation_denoise_progress(
                #     context._data.queue_item,
                #     context._data.invocation,
                #     state,
                #     ProgressImage(dataURL=dataURL, width=width, height=height),
                # )

            x = denoise(
                model=transformer,
                img=img,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil
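The last line of `_run_vae_decoding` maps the decoder output from the [-1, 1] range onto 8-bit pixel values via `127.5 * (img + 1.0)`. A standalone sketch of just that step, with a made-up tensor standing in for the decoder output and an arbitrary file name:

```python
import torch
from PIL import Image

# Fake decoder output in [-1, 1], shaped (channels, height, width).
img = torch.rand(3, 64, 64) * 2.0 - 1.0

hwc = img.permute(1, 2, 0)                            # c h w -> h w c
pixels = (127.5 * (hwc + 1.0)).byte().cpu().numpy()   # -1 -> 0, +1 -> 255
Image.fromarray(pixels).save("decoded.png")
```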
@@ -1,60 +0,0 @@
import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_vae_decode",
    title="FLUX Latents to Image",
    tags=["latents", "image", "vae", "l2i", "flux"],
    category="latents",
    version="1.0.0",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents."""

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
        return img_pil

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)
        vae_info = context.models.load(self.vae.vae)
        image = self._vae_decode(vae_info=vae_info, latents=latents)

        TorchDevice.empty_cache()
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
@@ -1,67 +0,0 @@
import einops
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_vae_encode",
    title="FLUX Image to Latents",
    tags=["latents", "image", "vae", "i2l", "flux"],
    category="latents",
    version="1.0.0",
)
class FluxVaeEncodeInvocation(BaseInvocation):
    """Encodes an image into latents."""

    image: ImageField = InputField(
        description="The image to encode.",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Expose seed parameter at the invocation level.
        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
        # should be used for VAE encode sampling.
        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            image_tensor = image_tensor.to(
                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
            )
            latents = vae.encode(image_tensor, sample=True, generator=generator)
            return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        vae_info = context.models.load(self.vae.vae)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
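`vae_encode` above passes `sample=True` together with a fixed-seed `torch.Generator`, so the stochastic part of the encode is repeatable run to run. The repo's `AutoEncoder.encode` is not shown in this diff; the snippet below is only a generic sketch of why a seeded generator makes that kind of reparameterized sampling deterministic.

```python
import torch


def sample_latent(mean: torch.Tensor, logvar: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Reparameterized sample: mean + std * noise, with noise drawn from the supplied generator.
    std = torch.exp(0.5 * logvar)
    noise = torch.randn(mean.shape, generator=generator)
    return mean + std * noise


mean, logvar = torch.zeros(1, 16), torch.zeros(1, 16)
a = sample_latent(mean, logvar, torch.Generator().manual_seed(0))
b = sample_latent(mean, logvar, torch.Generator().manual_seed(0))
assert torch.equal(a, b)  # identical draws for identical seeds
```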
@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
    title="Tensor Mask to Image",
    tags=["mask"],
    category="mask",
    version="1.1.0",
    version="1.0.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Convert a mask tensor to an image."""
@@ -135,11 +135,6 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    def invoke(self, context: InvocationContext) -> ImageOutput:
        mask = context.tensors.load(self.mask.tensor_name)

        # Squeeze the channel dimension if it exists.
        if mask.dim() == 3:
            mask = mask.squeeze(0)

        # Ensure that the mask is binary.
        if mask.dtype != torch.bool:
            mask = mask > 0.5
@@ -69,7 +69,6 @@ class CLIPField(BaseModel):

class TransformerField(BaseModel):
    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
    loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class T5EncoderField(BaseModel):
@@ -158,7 +157,7 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
    title="Flux Main Model",
    tags=["model", "flux"],
    category="model",
    version="1.0.4",
    version="1.0.3",
    classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
@@ -170,46 +169,80 @@ class FluxModelLoaderInvocation(BaseInvocation):
        input=Input.Direct,
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
    )

    clip_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
    t5_encoder: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder,
        ui_type=UIType.T5EncoderModel,
        input=Input.Direct,
        title="CLIP Embed",
    )

    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
    )

    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
        model_key = self.model.key

        if not context.models.exists(model_key):
            raise ValueError(f"Unknown model: {model_key}")
        transformer = self._get_model(context, SubModelType.Transformer)
        tokenizer = self._get_model(context, SubModelType.Tokenizer)
        tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
        clip_encoder = self._get_model(context, SubModelType.TextEncoder)
        t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
        vae = self._get_model(context, SubModelType.VAE)
        transformer_config = context.models.get_config(transformer)
        assert isinstance(transformer_config, CheckpointConfigBase)

        return FluxModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            transformer=TransformerField(transformer=transformer),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=max_seq_lengths[transformer_config.config_path],
        )

    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
        match submodel:
            case SubModelType.Transformer:
                return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
            case SubModelType.VAE:
                return self._pull_model_from_mm(
                    context,
                    SubModelType.VAE,
                    "FLUX.1-schnell_ae",
                    ModelType.VAE,
                    BaseModelType.Flux,
                )
            case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
                return self._pull_model_from_mm(
                    context,
                    submodel,
                    "clip-vit-large-patch14",
                    ModelType.CLIPEmbed,
                    BaseModelType.Any,
                )
            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
                return self._pull_model_from_mm(
                    context,
                    submodel,
                    self.t5_encoder.name,
                    ModelType.T5Encoder,
                    BaseModelType.Any,
                )
            case _:
                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")

    def _pull_model_from_mm(
        self,
        context: InvocationContext,
        submodel: SubModelType,
        name: str,
        type: ModelType,
        base: BaseModelType,
    ):
        if models := context.models.search_by_attrs(name=name, base=base, type=type):
            if len(models) != 1:
                raise Exception(f"Multiple models detected for selected model with name {name}")
            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
|
||||
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
|
||||
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
@@ -204,11 +204,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
LoraPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
):
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
|
||||
@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
        if self.variant:
            base += f":{self.variant or ''}"
        if self.subfolder:
            base += f"::{self.subfolder.as_posix()}"
            base += f":{self.subfolder}"
        return base

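For illustration only (the repo id, variant, and subfolder values below are made up, not taken from this diff), the `::` separator introduced above produces source strings of the following shape:

from pathlib import Path

repo_id = "black-forest-labs/FLUX.1-schnell"  # hypothetical
variant = "fp16"                              # hypothetical
subfolder = Path("vae")                       # hypothetical

base = repo_id
if variant:
    base += f":{variant or ''}"
if subfolder:
    base += f"::{subfolder.as_posix()}"

print(base)  # black-forest-labs/FLUX.1-schnell:fp16::vae
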
@@ -6,7 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
CancelByOriginResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -97,8 +97,8 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
"""Cancels all queue items with the given batch destination"""
|
||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
||||
"""Cancels all queue items with the given batch origin"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -346,10 +346,10 @@ class CancelByBatchIDsResult(BaseModel):
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByDestinationResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by a destination"""
|
||||
class CancelByOriginResult(BaseModel):
|
||||
"""Result of canceling by list of batch ids"""
|
||||
|
||||
pass
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
|
||||
@@ -10,7 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
CancelByOriginResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -426,19 +426,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
||||
try:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND destination == ?
|
||||
AND origin == ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
params = (queue_id, origin)
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -457,14 +457,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
params,
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
if current_queue_item is not None and current_queue_item.origin == origin:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
return CancelByOriginResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -557,24 +557,6 @@ class UtilInterface(InvocationContextInterface):
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
def flux_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
|
||||
"""
|
||||
The step callback emits a progress event with the current step, the total number of
|
||||
steps, a preview image, and some other internal metadata.
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
Args:
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
flux_step_callback(
|
||||
context_data=self._data,
|
||||
intermediate_state=intermediate_state,
|
||||
events=self._services.events,
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
@@ -1,407 +0,0 @@
|
||||
{
|
||||
"name": "FLUX Image to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple image-to-image workflow using a FLUX dev model. ",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "image2image, flux, image-to-image",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "denoising_start"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "flux_vae_encode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"image": {
|
||||
"name": "image",
|
||||
"label": "",
|
||||
"value": {
|
||||
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 732.7680166609682,
|
||||
"y": -24.37398171806909
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0.04
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1182.8836633018684,
|
||||
"y": -251.38882958913183
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (dev variant recommended for Image-to-Image)"
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
|
||||
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
|
||||
"name": "clip-vit-large-patch14",
|
||||
"base": "any",
|
||||
"type": "clip_embed"
|
||||
}
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
|
||||
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
|
||||
"name": "FLUX.1-schnell_ae",
|
||||
"base": "flux",
|
||||
"type": "vae"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 328.1809894659957,
|
||||
"y": -90.2241133566946
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat wearing a birthday hat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 745.8823365057267,
|
||||
"y": -299.60249175851914
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 725.834098928012,
|
||||
"y": 496.2710031089931
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "height",
|
||||
"targetHandle": "height"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "width",
|
||||
"targetHandle": "width"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,35 +1,27 @@
|
||||
{
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
|
||||
"version": "1.0.4",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"version": "1.0.0",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"fieldName": "num_steps"
|
||||
},
|
||||
{
|
||||
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"fieldName": "t5_encoder"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
@@ -38,127 +30,12 @@
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1186.1868226120378,
|
||||
"y": -214.9459927686657
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"version": "1.0.3",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
@@ -167,25 +44,31 @@
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": ""
|
||||
"label": "Model (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
|
||||
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
|
||||
"name": "FLUX Dev (Quantized)",
|
||||
"base": "flux",
|
||||
"type": "main"
|
||||
}
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
|
||||
"value": {
|
||||
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
|
||||
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
|
||||
"name": "t5_bnb_int8_quantized_encoder",
|
||||
"base": "any",
|
||||
"type": "t5_encoder"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
"x": 337.09365228062825,
|
||||
"y": 40.63469521079861
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -222,8 +105,8 @@
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 778.4899149328337,
|
||||
"y": -100.36469216659502
|
||||
"x": 824.1970602278849,
|
||||
"y": 146.98251001061735
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -252,75 +135,132 @@
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 800.9667463219505,
|
||||
"y": 285.8297267547506
|
||||
"x": 822.9899179655476,
|
||||
"y": 360.9657214885052
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"type": "flux_text_to_image",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1216.3900791301849,
|
||||
"y": 5.500841807102248
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -38,25 +38,6 @@ SD1_5_LATENT_RGB_FACTORS = [
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
FLUX_LATENT_RGB_FACTORS = [
|
||||
[-0.0412, 0.0149, 0.0521],
|
||||
[0.0056, 0.0291, 0.0768],
|
||||
[0.0342, -0.0681, -0.0427],
|
||||
[-0.0258, 0.0092, 0.0463],
|
||||
[0.0863, 0.0784, 0.0547],
|
||||
[-0.0017, 0.0402, 0.0158],
|
||||
[0.0501, 0.1058, 0.1152],
|
||||
[-0.0209, -0.0218, -0.0329],
|
||||
[-0.0314, 0.0083, 0.0896],
|
||||
[0.0851, 0.0665, -0.0472],
|
||||
[-0.0534, 0.0238, -0.0024],
|
||||
[0.0452, -0.0026, 0.0048],
|
||||
[0.0892, 0.0831, 0.0881],
|
||||
[-0.1117, -0.0304, -0.0789],
|
||||
[0.0027, -0.0479, -0.0043],
|
||||
[-0.1146, -0.0827, -0.0598],
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
@@ -113,32 +94,3 @@ def stable_diffusion_step_callback(
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
||||
|
||||
def flux_step_callback(
|
||||
context_data: "InvocationContextData",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
events: "EventServiceBase",
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
sample = intermediate_state.latents
|
||||
latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
|
||||
latent_image = latent_image_perm @ latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255
|
||||
).to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(latents_ubyte.cpu().numpy())
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
):
|
||||
step = 0
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
preview_img = img - t_curr * pred
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
if inpaint_extension is not None:
|
||||
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step,
|
||||
order=1,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t_curr),
|
||||
latents=preview_img,
|
||||
),
|
||||
)
|
||||
step += 1
|
||||
|
||||
return img
|
||||
@@ -1,35 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class InpaintExtension:
|
||||
"""A class for managing inpainting with FLUX."""
|
||||
|
||||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||
"""Initialize InpaintExtension.
|
||||
|
||||
Args:
|
||||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
|
||||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||
inpainted region with the background. In 'packed' format.
|
||||
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
|
||||
"""
|
||||
assert init_latents.shape == inpaint_mask.shape == noise.shape
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, timestep: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||
trajectory.
|
||||
|
||||
This function should be called after each denoising step.
|
||||
"""
|
||||
# Noise the init latents for the current timestep.
|
||||
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
|
||||
|
||||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
|
||||
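A rough usage sketch of the extension above (the tensor shapes and values are illustrative assumptions, not taken from this diff):

import torch

init_latents = torch.randn(1, 4096, 64)        # un-noised latents at t=0, in 'packed' format
noise = torch.randn_like(init_latents)         # the noise used to noise init_latents
inpaint_mask = torch.zeros_like(init_latents)  # 0 = keep, 1 = re-generate
inpaint_mask[:, :2048, :] = 1.0                # re-generate the first half of the patches

ext = InpaintExtension(init_latents=init_latents, inpaint_mask=inpaint_mask, noise=noise)

# After a denoising step that produced `intermediate_latents` at timestep t_prev:
t_prev = 0.5
intermediate_latents = torch.randn_like(init_latents)
merged = ext.merge_intermediate_latents_with_init_latents(intermediate_latents, t_prev)
# Masked regions keep the denoiser's output; unmasked regions are reset onto the
# noise trajectory of the original image, i.e. noise * t + (1 - t) * init_latents.
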
@@ -258,17 +258,16 @@ class Decoder(nn.Module):
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, chunk_dim: int = 1):
|
||||
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if sample:
|
||||
if self.sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
|
||||
# have to use torch.randn(...) instead.
|
||||
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
else:
|
||||
return mean
|
||||
|
||||
@@ -298,21 +297,8 @@ class AutoEncoder(nn.Module):
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
"""Run VAE encoding on input tensor x.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
|
||||
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
|
||||
Defaults to True.
|
||||
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
z = self.reg(self.encoder(x))
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
|
||||
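For reference, a hedged sketch of how the added `sample`/`generator` arguments make encoding reproducible; the `ae` instance and input tensor `x` are assumed to exist and are not part of this diff:

import torch

# ae: AutoEncoder, x: image tensor of shape (batch, in_channels, H, W)
g1 = torch.Generator(device=x.device).manual_seed(0)
g2 = torch.Generator(device=x.device).manual_seed(0)

z1 = ae.encode(x, sample=True, generator=g1)
z2 = ae.encode(x, sample=True, generator=g2)
assert torch.allclose(z1, z2)        # same seed -> identical sampled latents

z_mean = ae.encode(x, sample=False)  # deterministic: returns the distribution mean
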
176
invokeai/backend/flux/sampling.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
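# Illustration (not part of this file): because the noise is always drawn on the CPU in float16
# and only then cast, the same seed yields the same latent noise no matter which device/dtype the
# caller requests, e.g.:
#
#   n_a = get_noise(1, 1024, 1024, torch.device("cpu"), torch.float32, seed=123)
#   n_b = get_noise(1, 1024, 1024, torch.device("cuda"), torch.bfloat16, seed=123)
#   # n_a and n_b agree up to the precision lost in the dtype cast.
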
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# estimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
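# Worked example (illustrative, not part of this file): with num_steps=4 and shift=False the
# schedule is the plain linspace [1.0, 0.75, 0.5, 0.25, 0.0]. With shift=True, mu is estimated
# from image_seq_len via get_lin_function, and time_shift(mu, 1.0, t) leaves the endpoints at
# 1.0 and 0.0 but pushes the interior timesteps upward, so more of the denoising budget is spent
# at the high-noise end of the trajectory for larger images.
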
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
dtype = model.txt_in.bias.dtype
|
||||
|
||||
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
|
||||
img = img.to(dtype=dtype)
|
||||
img_ids = img_ids.to(dtype=dtype)
|
||||
txt = txt.to(dtype=dtype)
|
||||
txt_ids = txt_ids.to(dtype=dtype)
|
||||
vec = vec.to(dtype=dtype)
|
||||
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
step_callback()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
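# Note on the update rule above (added for clarity, not part of the original file): `pred` is the
# model's estimate of the flow velocity, and `img = img + (t_prev - t_curr) * pred` is one Euler
# step along the timestep schedule. The schedule runs from 1.0 (pure noise) down to 0.0, so
# (t_prev - t_curr) is negative; e.g. stepping from t_curr=1.0 to t_prev=0.75 applies
# img + (-0.25) * pred, moving the latents part of the way toward the final image.
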
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
return img, img_ids
|
||||
@@ -1,135 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# estimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
|
||||
"""Find the last index in timesteps that is >= val.
|
||||
|
||||
We use epsilon-close equality to avoid potential floating point errors.
|
||||
"""
|
||||
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
|
||||
assert idx >= 0
|
||||
return idx
|
||||
|
||||
|
||||
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
|
||||
"""Clip the timestep schedule to the denoising range.
|
||||
|
||||
Args:
|
||||
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
|
||||
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
|
||||
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
|
||||
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
|
||||
mean that the denoising process end at the last timestep in the schedule >= 0.2.
|
||||
|
||||
Returns:
|
||||
list[float]: The clipped timestep schedule.
|
||||
"""
|
||||
assert 0.0 <= denoising_start <= 1.0
|
||||
assert 0.0 <= denoising_end <= 1.0
|
||||
assert denoising_start <= denoising_end
|
||||
|
||||
t_start_val = 1.0 - denoising_start
|
||||
t_end_val = 1.0 - denoising_end
|
||||
|
||||
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
|
||||
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
|
||||
|
||||
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
|
||||
|
||||
return clipped_timesteps
|
||||
|
||||
|
||||
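# Worked example (illustrative, not part of this file): for timesteps = [1.0, 0.75, 0.5, 0.25, 0.0],
# denoising_start=0.25 and denoising_end=1.0 give t_start_val=0.75 and t_end_val=0.0, so
# t_start_idx=1 and t_end_idx=4 and the clipped schedule is [0.75, 0.5, 0.25, 0.0]; that is, the
# first quarter of the denoising process is skipped, as in image-to-image with an init image.
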
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""Unpack flat array of patch embeddings to latent image."""
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def pack(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack latent image to flattented array of patch embeddings."""
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
|
||||
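# Round-trip sketch (illustrative, not part of this file): pack(...) is a pixel-unshuffle with a
# scale of 2 followed by flattening the spatial grid, and unpack(...) inverts it given the original
# image size:
#
#   latents = torch.randn(1, 16, 128, 128)              # latent image for a 1024x1024 input
#   packed = pack(latents)                               # shape (1, 64 * 64, 16 * 2 * 2) = (1, 4096, 64)
#   restored = unpack(packed, height=1024, width=1024)   # shape (1, 16, 128, 128)
#   assert restored.shape == latents.shape
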
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids.
|
||||
|
||||
Args:
|
||||
h (int): Height of image in latent space.
|
||||
w (int): Width of image in latent space.
|
||||
batch_size (int): Batch size.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): dtype.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids.
|
||||
"""
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
672
invokeai/backend/lora.py
Normal file
@@ -0,0 +1,672 @@
# Copyright (c) 2024 The InvokeAI Development team

"""LoRA model support."""

import bisect
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
from safetensors.torch import load_file
from typing_extensions import Self

import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel


class LoRALayerBase:
    # rank: Optional[int]
    # alpha: Optional[float]
    # bias: Optional[torch.Tensor]
    # layer_key: str

    # @property
    # def scale(self):
    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        if "alpha" in values:
            self.alpha = values["alpha"].item()
        else:
            self.alpha = None

        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
            self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )

        else:
            self.bias = None

        self.rank = None  # set in layer implementation
        self.layer_key = layer_key

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        params = {"weight": self.get_weight(orig_module.weight)}
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias
        return params

    def calc_size(self) -> int:
        model_size = 0
        for val in [self.bias]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)

    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
        """Log a warning if values contains unhandled keys."""
        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
        unknown_keys = set(values.keys()) - all_known_keys
        if unknown_keys:
            logger.warning(
                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
            )

# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
    # up: torch.Tensor
    # mid: Optional[torch.Tensor]
    # down: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
        self.mid = values.get("lora_mid.weight", None)

        self.rank = self.down.shape[0]
        self.check_keys(
            values,
            {
                "lora_up.weight",
                "lora_down.weight",
                "lora_mid.weight",
            },
        )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.mid is not None:
            up = self.up.reshape(self.up.shape[0], self.up.shape[1])
            down = self.down.reshape(self.down.shape[0], self.down.shape[1])
            weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
        else:
            weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.up, self.mid, self.down]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)


class LoHALayer(LoRALayerBase):
    # w1_a: torch.Tensor
    # w1_b: torch.Tensor
    # w2_a: torch.Tensor
    # w2_b: torch.Tensor
    # t1: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
        super().__init__(layer_key, values)

        self.w1_a = values["hada_w1_a"]
        self.w1_b = values["hada_w1_b"]
        self.w2_a = values["hada_w2_a"]
        self.w2_b = values["hada_w2_b"]
        self.t1 = values.get("hada_t1", None)
        self.t2 = values.get("hada_t2", None)

        self.rank = self.w1_b.shape[0]
        self.check_keys(
            values,
            {
                "hada_w1_a",
                "hada_w1_b",
                "hada_w2_a",
                "hada_w2_b",
                "hada_t1",
                "hada_t2",
            },
        )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        if self.t1 is None:
            weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

        else:
            rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
            rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
            weight = rebuild1 * rebuild2

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
        if self.t1 is not None:
            self.t1 = self.t1.to(device=device, dtype=dtype)

        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)

class LoKRLayer(LoRALayerBase):
    # w1: Optional[torch.Tensor] = None
    # w1_a: Optional[torch.Tensor] = None
    # w1_b: Optional[torch.Tensor] = None
    # w2: Optional[torch.Tensor] = None
    # w2_a: Optional[torch.Tensor] = None
    # w2_b: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.w1 = values.get("lokr_w1", None)
        if self.w1 is None:
            self.w1_a = values["lokr_w1_a"]
            self.w1_b = values["lokr_w1_b"]
        else:
            self.w1_b = None
            self.w1_a = None

        self.w2 = values.get("lokr_w2", None)
        if self.w2 is None:
            self.w2_a = values["lokr_w2_a"]
            self.w2_b = values["lokr_w2_b"]
        else:
            self.w2_a = None
            self.w2_b = None

        self.t2 = values.get("lokr_t2", None)

        if self.w1_b is not None:
            self.rank = self.w1_b.shape[0]
        elif self.w2_b is not None:
            self.rank = self.w2_b.shape[0]
        else:
            self.rank = None  # unscaled

        self.check_keys(
            values,
            {
                "lokr_w1",
                "lokr_w1_a",
                "lokr_w1_b",
                "lokr_w2",
                "lokr_w2_a",
                "lokr_w2_b",
                "lokr_t2",
            },
        )

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        w1: Optional[torch.Tensor] = self.w1
        if w1 is None:
            assert self.w1_a is not None
            assert self.w1_b is not None
            w1 = self.w1_a @ self.w1_b

        w2 = self.w2
        if w2 is None:
            if self.t2 is None:
                assert self.w2_a is not None
                assert self.w2_b is not None
                w2 = self.w2_a @ self.w2_b
            else:
                w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)

        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
        assert w1 is not None
        assert w2 is not None
        weight = torch.kron(w1, w2)

        return weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        if self.w1 is not None:
            self.w1 = self.w1.to(device=device, dtype=dtype)
        else:
            assert self.w1_a is not None
            assert self.w1_b is not None
            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

        if self.w2 is not None:
            self.w2 = self.w2.to(device=device, dtype=dtype)
        else:
            assert self.w2_a is not None
            assert self.w2_b is not None
            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)

class FullLayer(LoRALayerBase):
    # bias handled in LoRALayerBase(calc_size, to)
    # weight: torch.Tensor
    # bias: Optional[torch.Tensor]

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.weight = values["diff"]
        self.bias = values.get("diff_b", None)

        self.rank = None  # unscaled
        self.check_keys(values, {"diff", "diff_b"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)


class IA3Layer(LoRALayerBase):
    # weight: torch.Tensor
    # on_input: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.weight = values["weight"]
        self.on_input = values["on_input"]

        self.rank = None  # unscaled
        self.check_keys(values, {"weight", "on_input"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if not self.on_input:
            weight = weight.reshape(-1, 1)
        assert orig_weight is not None
        return orig_weight * weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        model_size += self.on_input.nelement() * self.on_input.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
        self.on_input = self.on_input.to(device=device, dtype=dtype)


class NormLayer(LoRALayerBase):
    # bias handled in LoRALayerBase(calc_size, to)
    # weight: torch.Tensor
    # bias: Optional[torch.Tensor]

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.weight = values["w_norm"]
        self.bias = values.get("b_norm", None)

        self.rank = None  # unscaled
        self.check_keys(values, {"w_norm", "b_norm"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)


AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]

class LoRAModelRaw(RawModel):  # (torch.nn.Module):
    _name: str
    layers: Dict[str, AnyLoRALayer]

    def __init__(
        self,
        name: str,
        layers: Dict[str, AnyLoRALayer],
    ):
        self._name = name
        self.layers = layers

    @property
    def name(self) -> str:
        return self._name

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        # TODO: try revert if exception?
        for _key, layer in self.layers.items():
            layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        model_size = 0
        for _, layer in self.layers.items():
            model_size += layer.calc_size()
        return model_size

    @classmethod
    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert the keys of an SDXL LoRA state_dict to diffusers format.

        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
        diffusers format, then this function will have no effect.

        This function is adapted from:
        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409

        Args:
            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.

        Raises:
            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.

        Returns:
            Dict[str, Tensor]: The diffusers-format state_dict.
        """
        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
        not_converted_count = 0  # The number of keys that were not converted.

        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
        # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
        # `input_blocks_4_1_proj_in`.
        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
        stability_unet_keys.sort()

        new_state_dict = {}
        for full_key, value in state_dict.items():
            if full_key.startswith("lora_unet_"):
                search_key = full_key.replace("lora_unet_", "")
                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
                position = bisect.bisect_right(stability_unet_keys, search_key)
                map_key = stability_unet_keys[position - 1]
                # Now, check if the map_key *actually* matches the search_key.
                if search_key.startswith(map_key):
                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
                    new_state_dict[new_key] = value
                    converted_count += 1
                else:
                    new_state_dict[full_key] = value
                    not_converted_count += 1
            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
                new_state_dict[full_key] = value
                continue
            else:
                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")

        if converted_count > 0 and not_converted_count > 0:
            raise ValueError(
                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
                f" not_converted={not_converted_count}"
            )

        return new_state_dict

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        base_model: Optional[BaseModelType] = None,
    ) -> Self:
        device = device or torch.device("cpu")
        dtype = dtype or torch.float32

        if isinstance(file_path, str):
            file_path = Path(file_path)

        model = cls(
            name=file_path.stem,
            layers={},
        )

        if file_path.suffix == ".safetensors":
            sd = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            sd = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(sd)

        if base_model == BaseModelType.StableDiffusionXL:
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

        for layer_key, values in state_dict.items():
            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

            # lora and locon
            if "lora_up.weight" in values:
                layer: AnyLoRALayer = LoRALayer(layer_key, values)

            # loha
            elif "hada_w1_a" in values:
                layer = LoHALayer(layer_key, values)

            # lokr
            elif "lokr_w1" in values or "lokr_w1_a" in values:
                layer = LoKRLayer(layer_key, values)

            # diff
            elif "diff" in values:
                layer = FullLayer(layer_key, values)

            # ia3
            elif "on_input" in values:
                layer = IA3Layer(layer_key, values)

            # norms
            elif "w_norm" in values:
                layer = NormLayer(layer_key, values)

            else:
                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                raise Exception("Unknown lora format!")

            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

            layer.to(device=device, dtype=dtype)
            model.layers[layer_key] = layer

        return model

    @staticmethod
    def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
        state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

        for key, value in state_dict.items():
            stem, leaf = key.split(".", 1)
            if stem not in state_dict_groupped:
                state_dict_groupped[stem] = {}
            state_dict_groupped[stem][leaf] = value

        return state_dict_groupped

# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2*j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
        sd_time_embed_prefix = f"time_embed.{j*2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
        sd_label_embed_prefix = f"label_emb.0.{j*2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    return unet_conversion_map


SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}
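
For context, a minimal usage sketch of the LoRAModelRaw.from_checkpoint entry point defined above; the checkpoint path and dtype are hypothetical placeholders, not values taken from this diff.

# Sketch only: illustrates the from_checkpoint flow shown above.
import torch

from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType

lora = LoRAModelRaw.from_checkpoint(
    file_path="/path/to/some_sdxl_lora.safetensors",  # hypothetical path
    device=torch.device("cpu"),
    dtype=torch.float16,
    base_model=BaseModelType.StableDiffusionXL,  # triggers the SDXL key conversion above
)
print(lora.name, lora.calc_size())  # name is the file stem; calc_size() is in bytes
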
@@ -1,210 +0,0 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw


def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
    """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
    all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())

    # Next, check that this is likely a FLUX model by spot-checking a few keys.
    expected_keys = [
        "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight",
        "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight",
        "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight",
    ]
    all_expected_keys_present = all(k in state_dict for k in expected_keys)

    return all_keys_in_peft_format and all_expected_keys_present


# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw:  # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
    """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.

    This function is based on:
    https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
    """
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)

    # Remove the "transformer." prefix from all keys.
    grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}

    # Constants for FLUX.1
    num_double_layers = 19
    num_single_layers = 38
    # inner_dim = 3072
    # mlp_ratio = 4.0

    layers: dict[str, AnyLoRALayer] = {}

    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
        if src_key in grouped_state_dict:
            src_layer_dict = grouped_state_dict.pop(src_key)
            layers[dst_key] = LoRALayer(
                values={
                    "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                    "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                    "alpha": torch.tensor(alpha),
                },
            )
            assert len(src_layer_dict) == 0

    def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
        """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
        stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
        """
        # We expect that either all src keys are present or none of them are. Verify this.
        keys_present = [key in grouped_state_dict for key in src_keys]
        assert all(keys_present) or not any(keys_present)

        # If none of the keys are present, return early.
        if not any(keys_present):
            return

        src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
        sub_layers: list[LoRALayerBase] = []
        for src_layer_dict in src_layer_dicts:
            sub_layers.append(
                LoRALayer(
                    values={
                        "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                        "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
                        "alpha": torch.tensor(alpha),
                    },
                )
            )
            assert len(src_layer_dict) == 0
        layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)

    # time_text_embed.timestep_embedder -> time_in.
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer")

    # time_text_embed.text_embedder -> vector_in.
    add_lora_layer_if_present("time_text_embed.text_embedder.linear_1", "vector_in.in_layer")
    add_lora_layer_if_present("time_text_embed.text_embedder.linear_2", "vector_in.out_layer")

    # time_text_embed.guidance_embedder -> guidance_in.
    add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_1", "guidance_in")
    add_lora_layer_if_present("time_text_embed.guidance_embedder.linear_2", "guidance_in")

    # context_embedder -> txt_in.
    add_lora_layer_if_present("context_embedder", "txt_in")

    # x_embedder -> img_in.
    add_lora_layer_if_present("x_embedder", "img_in")

    # Double transformer blocks.
    for i in range(num_double_layers):
        # norms.
        add_lora_layer_if_present(f"transformer_blocks.{i}.norm1.linear", f"double_blocks.{i}.img_mod.lin")
        add_lora_layer_if_present(f"transformer_blocks.{i}.norm1_context.linear", f"double_blocks.{i}.txt_mod.lin")

        # Q, K, V
        add_qkv_lora_layer_if_present(
            [
                f"transformer_blocks.{i}.attn.to_q",
                f"transformer_blocks.{i}.attn.to_k",
                f"transformer_blocks.{i}.attn.to_v",
            ],
            f"double_blocks.{i}.img_attn.qkv",
        )
        add_qkv_lora_layer_if_present(
            [
                f"transformer_blocks.{i}.attn.add_q_proj",
                f"transformer_blocks.{i}.attn.add_k_proj",
                f"transformer_blocks.{i}.attn.add_v_proj",
            ],
            f"double_blocks.{i}.txt_attn.qkv",
        )

        # ff img_mlp
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff.net.0.proj",
            f"double_blocks.{i}.img_mlp.0",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff.net.2",
            f"double_blocks.{i}.img_mlp.2",
        )

        # ff txt_mlp
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff_context.net.0.proj",
            f"double_blocks.{i}.txt_mlp.0",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.ff_context.net.2",
            f"double_blocks.{i}.txt_mlp.2",
        )

        # output projections.
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.attn.to_out.0",
            f"double_blocks.{i}.img_attn.proj",
        )
        add_lora_layer_if_present(
            f"transformer_blocks.{i}.attn.to_add_out",
            f"double_blocks.{i}.txt_attn.proj",
        )

    # Single transformer blocks.
    for i in range(num_single_layers):
        # norms
        add_lora_layer_if_present(
            f"single_transformer_blocks.{i}.norm.linear",
            f"single_blocks.{i}.modulation.lin",
        )

        # Q, K, V, mlp
        add_qkv_lora_layer_if_present(
            [
                f"single_transformer_blocks.{i}.attn.to_q",
                f"single_transformer_blocks.{i}.attn.to_k",
                f"single_transformer_blocks.{i}.attn.to_v",
                f"single_transformer_blocks.{i}.proj_mlp",
            ],
            f"single_blocks.{i}.linear1",
        )

        # Output projections.
        add_lora_layer_if_present(
            f"single_transformer_blocks.{i}.proj_out",
            f"single_blocks.{i}.linear2",
        )

    # Final layer.
    add_lora_layer_if_present("proj_out", "final_layer.linear")

    # Assert that all keys were processed.
    assert len(grouped_state_dict) == 0

    return LoRAModelRaw(layers=layers)


def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    """Groups the keys in the state dict by layer."""
    layer_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key in state_dict:
        # Split the 'lora_A.weight' or 'lora_B.weight' suffix from the layer name.
        parts = key.rsplit(".", maxsplit=2)
        layer_name = parts[0]
        key_name = ".".join(parts[1:])
        if layer_name not in layer_dict:
            layer_dict[layer_name] = {}
        layer_dict[layer_name][key_name] = state_dict[key]
    return layer_dict
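
A rough illustration of how the detector above is used; the tiny state dict is a stand-in built only from the spot-check keys listed in the function, with arbitrary small shapes.

# Sketch under assumptions: four spot-check keys only, not a real FLUX LoRA.
import torch

sd = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 16),
    "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(16, 4),
    "transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight": torch.zeros(4, 16),
    "transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight": torch.zeros(16, 4),
}
assert is_state_dict_likely_in_flux_diffusers_format(sd)
# A complete Diffusers FLUX LoRA (with full Q/K/V groups per block) would then be
# loaded with: lora_model_from_flux_diffusers_state_dict(sd, alpha=1.0)
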
@@ -1,80 +0,0 @@
import re
from typing import Any, Dict, TypeVar

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format.
# Example keys:
#   lora_unet_double_blocks_0_img_attn_proj.alpha
#   lora_unet_double_blocks_0_img_attn_proj.lora_down.weight
#   lora_unet_double_blocks_0_img_attn_proj.lora_up.weight
FLUX_KOHYA_KEY_REGEX = (
    r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
)


def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys())


def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        layer_name, param_name = key.split(".", 1)
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Convert the state dict to the InvokeAI format.
    grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict)

    # Create LoRA layers.
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)


T = TypeVar("T")


def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Converts a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Example key conversions:
    "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj"
    "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv"
    """

    def replace_func(match: re.Match[str]) -> str:
        s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        if match.group(4):
            s += f".{match.group(4)}"
        return s

    converted_dict: dict[str, T] = {}
    for k, v in state_dict.items():
        match = re.match(FLUX_KOHYA_KEY_REGEX, k)
        if match:
            new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k)
            converted_dict[new_key] = v
        else:
            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")

    return converted_dict
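
A small, hedged sketch of the key rewrite performed by convert_flux_kohya_state_dict_to_invoke_format; the grouped entry and its alpha value are placeholders, only the renaming matters here.

# Sketch: the regex-driven rewrite of one grouped Kohya key.
import torch

grouped = {"lora_unet_double_blocks_0_img_attn_proj": {"alpha": torch.tensor(8.0)}}
converted = convert_flux_kohya_state_dict_to_invoke_format(grouped)
assert list(converted.keys()) == ["double_blocks.0.img_attn.proj"]
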
@@ -1,29 +0,0 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw


def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)

    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, values in grouped_state_dict.items():
        layers[layer_key] = any_lora_layer_from_state_dict(values)

    return LoRAModelRaw(layers=layers)


def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
    state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        stem, leaf = key.split(".", 1)
        if stem not in state_dict_groupped:
            state_dict_groupped[stem] = {}
        state_dict_groupped[stem][leaf] = value

    return state_dict_groupped
@@ -1,154 +0,0 @@
|
||||
import bisect
|
||||
from typing import Dict, List, Tuple, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def _make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in _make_sdxl_unet_conversion_map()
|
||||
}
|
||||
@@ -1,11 +0,0 @@
from typing import Union

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer

AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
@@ -1,46 +0,0 @@
from typing import List, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class ConcatenatedLoRALayer(LoRALayerBase):
    """A LoRA layer that is composed of multiple LoRA layers concatenated along a specified axis.

    This class was created to handle a special case with FLUX LoRA models. In the BFL FLUX model format, the attention
    Q, K, V matrices are concatenated along the first dimension. In the diffusers LoRA format, the Q, K, V matrices are
    stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
    """

    def __init__(self, lora_layers: List[LoRALayerBase], concat_axis: int = 0):
        # Note: We pass values={} to the base class, because the values are handled by the individual LoRA layers.
        super().__init__(values={})

        self._lora_layers = lora_layers
        self._concat_axis = concat_axis

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # TODO(ryand): Currently, we pass orig_weight=None to the sub-layers. If we want to support sub-layers that
        # require this value, we will need to implement chunking of the original weight tensor here.
        layer_weights = [lora_layer.get_weight(None) for lora_layer in self._lora_layers]  # pyright: ignore[reportArgumentType]
        return torch.cat(layer_weights, dim=self._concat_axis)

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that
        # require this value, we will need to implement chunking of the original bias tensor here.
        layer_biases = [lora_layer.get_bias(None) for lora_layer in self._lora_layers]  # pyright: ignore[reportArgumentType]
        layer_bias_is_none = [layer_bias is None for layer_bias in layer_biases]
        if any(layer_bias_is_none):
            assert all(layer_bias_is_none)
            return None

        # Ignore the type error, because we have just verified that all layer biases are non-None.
        return torch.cat(layer_biases, dim=self._concat_axis)

    def calc_size(self) -> int:
        return sum(lora_layer.calc_size() for lora_layer in self._lora_layers)

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        for lora_layer in self._lora_layers:
            lora_layer.to(device=device, dtype=dtype)
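
A minimal sketch of the Q/K/V concatenation this class performs, using three toy sub-layers built with the values-only LoRALayer constructor from this package; the shapes and the helper below are illustrative assumptions.

# Sketch: three rank-2 LoRA layers whose effective weights are stacked along dim 0,
# mimicking how separate Q, K, V LoRAs are merged into a fused BFL-style qkv matrix.
import torch

def _toy_lora_layer(out_features: int, in_features: int, rank: int = 2) -> LoRALayer:
    # Hypothetical helper for this sketch only.
    return LoRALayer(
        values={
            "lora_up.weight": torch.zeros(out_features, rank),
            "lora_down.weight": torch.zeros(rank, in_features),
        },
    )

q, k, v = (_toy_lora_layer(8, 16) for _ in range(3))
qkv = ConcatenatedLoRALayer(lora_layers=[q, k, v], concat_axis=0)
# Passing None mirrors how ConcatenatedLoRALayer itself calls its sub-layers.
assert qkv.get_weight(None).shape == (24, 16)  # 3 x 8 rows stacked along axis 0
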
@@ -1,36 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# bias handled in LoRALayerBase(calc_size, to)
|
||||
# weight: torch.Tensor
|
||||
# bias: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
self.bias = values.get("diff_b", None)
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"diff", "diff_b"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
@@ -1,41 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
self.check_keys(values, {"weight", "on_input"})
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
@@ -1,68 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"hada_w1_a",
|
||||
"hada_w1_b",
|
||||
"hada_w2_a",
|
||||
"hada_w2_b",
|
||||
"hada_t1",
|
||||
"hada_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
@@ -1,113 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
else:
|
||||
self.w1_b = None
|
||||
self.w1_a = None
|
||||
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
else:
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lokr_w1",
|
||||
"lokr_w1_a",
|
||||
"lokr_w1_b",
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
@@ -1,58 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
"lora_up.weight",
|
||||
"lora_down.weight",
|
||||
"lora_mid.weight",
|
||||
},
|
||||
)
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
@@ -1,71 +0,0 @@
from typing import Dict, Optional, Set

import torch

import invokeai.backend.util.logging as logger


class LoRALayerBase:
    # rank: Optional[int]
    # alpha: Optional[float]
    # bias: Optional[torch.Tensor]

    # @property
    # def scale(self):
    #     return self.alpha / self.rank if (self.alpha and self.rank) else 1.0

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        if "alpha" in values:
            self.alpha = values["alpha"].item()
        else:
            self.alpha = None

        if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
            self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
                values["bias_indices"],
                values["bias_values"],
                tuple(values["bias_size"]),
            )

        else:
            self.bias = None

        self.rank = None  # set in layer implementation

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
        return self.bias

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        params = {"weight": self.get_weight(orig_module.weight)}
        bias = self.get_bias(orig_module.bias)
        if bias is not None:
            params["bias"] = bias
        return params

    def calc_size(self) -> int:
        model_size = 0
        for val in [self.bias]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        if self.bias is not None:
            self.bias = self.bias.to(device=device, dtype=dtype)

    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
        """Log a warning if values contains unhandled keys."""
        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
        unknown_keys = set(values.keys()) - all_known_keys
        if unknown_keys:
            logger.warning(
                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
            )
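The `bias_indices` / `bias_values` / `bias_size` triplet handled above is simply a sparse COO encoding of a dense bias. A small illustration with invented numbers (only `torch` is required):

```python
import torch

# Invented example: a length-6 bias with two non-zero entries, stored the way
# LoRALayerBase.__init__ expects to receive it.
values = {
    "alpha": torch.tensor(8.0),
    "bias_indices": torch.tensor([[1, 4]]),     # COO indices, shape (ndim, nnz)
    "bias_values": torch.tensor([0.5, -0.25]),  # the non-zero entries
    "bias_size": torch.tensor([6]),             # dense shape of the bias
}

bias = torch.sparse_coo_tensor(
    values["bias_indices"],
    values["bias_values"],
    tuple(values["bias_size"]),
)
print(bias.to_dense())  # tensor([ 0.0000,  0.5000,  0.0000,  0.0000, -0.2500,  0.0000])
```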
@@ -1,36 +0,0 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase


class NormLayer(LoRALayerBase):
    # bias handled in LoRALayerBase(calc_size, to)
    # weight: torch.Tensor
    # bias: Optional[torch.Tensor]

    def __init__(
        self,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(values)

        self.weight = values["w_norm"]
        self.bias = values.get("b_norm", None)

        self.rank = None  # unscaled
        self.check_keys(values, {"w_norm", "b_norm"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight

    def calc_size(self) -> int:
        model_size = super().calc_size()
        model_size += self.weight.nelement() * self.weight.element_size()
        return model_size

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)
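Unlike the low-rank layers, `NormLayer` stores a full replacement for a norm's scale and (optionally) bias, so `get_weight()` returns the stored tensor untouched. A hedged usage sketch, assuming the `invokeai.backend.lora` modules listed in this diff are importable:

```python
import torch

from invokeai.backend.lora.layers.norm_layer import NormLayer

# Invented 320-channel norm weights, keyed the way LyCORIS stores a norm patch.
layer = NormLayer({"w_norm": torch.full((320,), 1.1), "b_norm": torch.zeros(320)})

print(layer.get_weight(torch.ones(320)))  # the stored w_norm, returned without any rank/alpha scaling
print(layer.calc_size())                  # 2 tensors * 320 float32 elements * 4 bytes = 2560
```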
@@ -1,33 +0,0 @@
from typing import Dict

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.full_layer import FullLayer
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
    # Detect layers according to LyCORIS detection logic(`weight_list_det`)
    # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules

    if "lora_up.weight" in state_dict:
        # LoRA a.k.a LoCon
        return LoRALayer(state_dict)
    elif "hada_w1_a" in state_dict:
        return LoHALayer(state_dict)
    elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
        return LoKRLayer(state_dict)
    elif "diff" in state_dict:
        # Full a.k.a Diff
        return FullLayer(state_dict)
    elif "on_input" in state_dict:
        return IA3Layer(state_dict)
    elif "w_norm" in state_dict:
        return NormLayer(state_dict)
    else:
        raise ValueError(f"Unsupported lora format: {state_dict.keys()}")
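The factory above dispatches purely on which keys are present in the state dict. A standalone sketch of the same routing (returning descriptive strings instead of layer instances so it runs without the invokeai package):

```python
import torch
from typing import Dict


def detect_lycoris_layer_type(state_dict: Dict[str, torch.Tensor]) -> str:
    # Mirrors the if/elif chain above; only the return values are simplified for the demo.
    if "lora_up.weight" in state_dict:
        return "LoRALayer (LoRA / LoCon)"
    elif "hada_w1_a" in state_dict:
        return "LoHALayer"
    elif "lokr_w1" in state_dict or "lokr_w1_a" in state_dict:
        return "LoKRLayer"
    elif "diff" in state_dict:
        return "FullLayer (full weight diff)"
    elif "on_input" in state_dict:
        return "IA3Layer"
    elif "w_norm" in state_dict:
        return "NormLayer"
    raise ValueError(f"Unsupported lora format: {state_dict.keys()}")


sd = {"lora_up.weight": torch.zeros(4, 2), "lora_down.weight": torch.zeros(2, 8)}
print(detect_lycoris_layer_type(sd))  # -> "LoRALayer (LoRA / LoCon)"
```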
@@ -1,22 +0,0 @@
# Copyright (c) 2024 The InvokeAI Development team
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.raw_model import RawModel


class LoRAModelRaw(RawModel):  # (torch.nn.Module):
    def __init__(self, layers: Dict[str, AnyLoRALayer]):
        self.layers = layers

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        for _key, layer in self.layers.items():
            layer.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        model_size = 0
        for _, layer in self.layers.items():
            model_size += layer.calc_size()
        return model_size
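`LoRAModelRaw.calc_size()` is just the sum of its layers' `calc_size()` results. A hedged example of the arithmetic, assuming the modules listed in this diff are importable:

```python
import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

# One rank-8 patch for a hypothetical 320x320 projection (names and shapes are invented).
layer = LoRALayer({"lora_up.weight": torch.zeros(320, 8), "lora_down.weight": torch.zeros(8, 320)})
lora = LoRAModelRaw(layers={"lora_unet_example_proj": layer})

# 2 tensors * (320 * 8) float32 elements * 4 bytes each = 20480 bytes
print(lora.calc_size())  # 20480
```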
@@ -1,148 +0,0 @@
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LoraPatcher:
    @staticmethod
    @torch.no_grad()
    @contextmanager
    def apply_lora_patches(
        model: torch.nn.Module,
        patches: Iterable[Tuple[LoRAModelRaw, float]],
        prefix: str,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Apply one or more LoRA patches to a model within a context manager.

        :param patches: An iterator that returns tuples of LoRA patches and associated weights. An iterator is used so
            that the LoRA patches do not need to be loaded into memory all at once.
        :param model: The model to patch.
        :param prefix: The keys in the patches will be filtered to only include weights with this prefix.
        :param cached_weights: Read-only copy of the model's state dict in CPU, for efficient unpatching purposes.
        """
        original_weights = OriginalWeightsStorage(cached_weights)
        try:
            for patch, patch_weight in patches:
                LoraPatcher.apply_lora_patch(
                    model=model,
                    prefix=prefix,
                    patch=patch,
                    patch_weight=patch_weight,
                    original_weights=original_weights,
                )
                del patch

            yield
        finally:
            for param_key, weight in original_weights.get_changed_weights():
                model.get_parameter(param_key).copy_(weight)

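A hedged end-to-end sketch of the context manager above, using a toy module and a synthetic patch (it assumes the `invokeai.backend.lora` modules shown in this diff are importable; all names and shapes are invented):

```python
import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoraPatcher


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8, bias=False)


model = ToyModel()

# A rank-2 patch whose key carries the same prefix we hand to the patcher.
layer = LoRALayer({"lora_up.weight": torch.randn(8, 2), "lora_down.weight": torch.randn(2, 8)})
lora = LoRAModelRaw(layers={"lora_unet_proj": layer})

before = model.proj.weight.detach().clone()
with LoraPatcher.apply_lora_patches(model=model, patches=[(lora, 1.0)], prefix="lora_unet_"):
    assert not torch.equal(model.proj.weight, before)  # weights carry the LoRA delta inside the context
assert torch.allclose(model.proj.weight, before)  # and are restored on exit
```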
    @staticmethod
    @torch.no_grad()
    def apply_lora_patch(
        model: torch.nn.Module,
        prefix: str,
        patch: LoRAModelRaw,
        patch_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        """
        Apply a single LoRA patch to a model.
        :param model: The model to patch.
        :param patch: LoRA model to patch in.
        :param patch_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRA's weight layers.
        :param original_weights: Storage for the original values of the weights that the LoRA modifies; used for unpatching.
        """

        if patch_weight == 0:
            return

        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
        # without searching, but some legacy code still uses flattened keys.
        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))

        prefix_len = len(prefix)

        for layer_key, layer in patch.layers.items():
            if not layer_key.startswith(prefix):
                continue

            module_key, module = LoraPatcher._get_submodule(
                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
            )

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype

            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

            # We intentionally move to the target device first, then cast. Experimentally, this was found to
            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
            # same thing in a single call to '.to(...)'.
            layer.to(device=device)
            layer.to(dtype=torch.float32)

            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
            for param_name, lora_param_weight in layer.get_parameters(module).items():
                param_key = module_key + "." + param_name
                module_param = module.get_parameter(param_name)

                # Save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != lora_param_weight.shape:
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)

                lora_param_weight *= patch_weight * layer_scale
                module_param += lora_param_weight.to(dtype=dtype)

            layer.to(device=TorchDevice.CPU_DEVICE)

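The two-step transfer noted in the comment above (move first, then cast) can be illustrated in isolation; this is only a sketch of the pattern, not a benchmark:

```python
import torch

t = torch.randn(1024, 1024, dtype=torch.float16)  # a 16-bit CPU tensor, as in a LoRA layer
if torch.cuda.is_available():
    t = t.to(device="cuda")          # 1) move while still fp16 (half the bytes cross the bus)
    t = t.to(dtype=torch.float32)    # 2) upcast on the GPU
else:
    t = t.to(dtype=torch.float32)    # CPU fallback: just cast
```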
    @staticmethod
    def _get_submodule(
        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
    ) -> tuple[str, torch.nn.Module]:
        """Get the submodule corresponding to the given layer key.
        :param model: The model to search.
        :param layer_key: The layer key to search for.
        :param layer_key_is_flattened: Whether the layer key is flattened. If flattened, then all '.' have been replaced
            with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly without
            searching, but some legacy code still uses flattened keys.
        :return: A tuple containing the module key and the submodule.
        """
        if not layer_key_is_flattened:
            return layer_key, model.get_submodule(layer_key)

        # Handle flattened keys.
        assert "." not in layer_key

        module = model
        module_key = ""
        key_parts = layer_key.split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return module_key, module
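Because Stable Diffusion LoRA keys replace '.' with '_', the resolver above has to greedily re-join name fragments until a real submodule is found. A minimal standalone re-implementation of that loop (toy module and key are invented for the example):

```python
import torch


def resolve_flattened_key(model: torch.nn.Module, flat_key: str) -> str:
    """Same greedy strategy as LoraPatcher._get_submodule for flattened keys."""
    module = model
    module_key = ""
    key_parts = flat_key.split("_")
    submodule_name = key_parts.pop(0)
    while len(key_parts) > 0:
        try:
            module = module.get_submodule(submodule_name)
            module_key += "." + submodule_name
            submodule_name = key_parts.pop(0)
        except Exception:
            # Not a submodule yet: the '_' was part of the original name, keep joining.
            submodule_name += "_" + key_parts.pop(0)
    module.get_submodule(submodule_name)  # raises if the final fragment is wrong
    return (module_key + "." + submodule_name).lstrip(".")


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(4, 4)


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.down_blocks = torch.nn.ModuleList([Block()])


print(resolve_flattened_key(Toy(), "down_blocks_0_to_q"))  # -> "down_blocks.0.to_q"
```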
@@ -66,14 +66,12 @@ class ModelLoader(ModelLoaderBase):
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
config.path = str(self._get_model_path(config))
|
||||
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
@@ -85,7 +83,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=stats_name,
|
||||
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||
)
|
||||
|
||||
def get_size_fs(
|
||||
|
||||
@@ -128,24 +128,7 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -210,6 +193,15 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
""" """
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import math
|
||||
@@ -24,74 +40,53 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
"""Implementation of ModelCacheBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
@@ -133,16 +128,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -160,6 +145,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
@@ -209,7 +203,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
@@ -237,13 +231,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
@@ -254,7 +245,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
@@ -312,7 +303,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@@ -335,14 +326,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
@@ -362,20 +353,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
@@ -392,7 +380,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
|
||||
@@ -32,9 +32,6 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
convert_bundle_to_flux_transformer_checkpoint,
|
||||
)
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
try:
|
||||
@@ -193,13 +190,6 @@ class FluxCheckpointModel(ModelLoader):
|
||||
with SilenceWarnings():
|
||||
model = Flux(params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
|
||||
self._ram_cache.make_room(new_sd_size)
|
||||
for k in sd.keys():
|
||||
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
|
||||
sd[k] = sd[k].to(torch.bfloat16)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
@@ -240,7 +230,5 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
model = Flux(params[config.config_path])
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
@@ -5,18 +5,8 @@ from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -55,33 +45,14 @@ class LoRALoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
|
||||
# Load the state dict from the model file.
|
||||
if model_path.suffix == ".safetensors":
|
||||
state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
|
||||
|
||||
model.to(dtype=self._torch_dtype)
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(self, config: AnyModelConfig) -> Path:
|
||||
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
self._model_base = config.base
|
||||
|
||||
@@ -15,7 +15,7 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
@@ -10,10 +10,6 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
@@ -112,8 +108,6 @@ class ModelProbe(object):
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
"CLIPModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -230,27 +224,14 @@ class ModelProbe(object):
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(
|
||||
(
|
||||
"cond_stage_model.",
|
||||
"first_stage_model.",
|
||||
"model.diffusion_model.",
|
||||
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
|
||||
"double_blocks.",
|
||||
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
|
||||
# This prefix is typically used to distinguish between multiple models bundled in a single file.
|
||||
"model.diffusion_model.double_blocks.",
|
||||
)
|
||||
):
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
|
||||
# Keys starting with double_blocks are associated with Flux models
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
return ModelType.LoRA
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
|
||||
return ModelType.LoRA
|
||||
elif key.startswith(("controlnet", "control_model", "input_blocks")):
|
||||
return ModelType.ControlNet
|
||||
@@ -302,16 +283,9 @@ class ModelProbe(object):
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
config_path = None
|
||||
for p in [
|
||||
folder_path / "model_index.json", # pipeline
|
||||
folder_path / "config.json", # most diffusers
|
||||
folder_path / "text_encoder_2" / "config.json", # T5 text encoder
|
||||
folder_path / "text_encoder" / "config.json", # T5 CLIP
|
||||
]:
|
||||
if p.exists():
|
||||
config_path = p
|
||||
break
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
@@ -354,10 +328,7 @@ class ModelProbe(object):
|
||||
# TODO: Decide between dev/schnell
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
if (
|
||||
"guidance_in.out_layer.weight" in state_dict
|
||||
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
|
||||
):
|
||||
if "guidance_in.out_layer.weight" in state_dict:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
@@ -365,7 +336,7 @@ class ModelProbe(object):
|
||||
config_file = "flux-dev"
|
||||
else:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the discriminator for model configs this
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
@@ -472,10 +443,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
if (
|
||||
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
|
||||
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
|
||||
):
|
||||
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
|
||||
return ModelFormat.BnbQuantizednf4b
|
||||
return ModelFormat("checkpoint")
|
||||
|
||||
@@ -502,10 +470,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
if (
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
|
||||
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
|
||||
):
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
return BaseModelType.Flux
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
@@ -560,21 +525,12 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for LoRA checkpoints."""
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint):
|
||||
# TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat
|
||||
# ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single
|
||||
# file, but the weight keys are in the diffusers format.
|
||||
return ModelFormat.Diffusers
|
||||
return ModelFormat.LyCORIS
|
||||
return ModelFormat("lycoris")
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
|
||||
self.checkpoint
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
# If we've gotten here, we assume that the model is a Stable Diffusion model.
|
||||
token_vector_length = lora_token_vector_length(self.checkpoint)
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
@@ -791,27 +747,8 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class T5EncoderFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
path = self.model_path / "text_encoder_2"
|
||||
if (path / "model.safetensors.index.json").exists():
|
||||
return ModelFormat.T5Encoder
|
||||
files = list(path.glob("*.safetensors"))
|
||||
if len(files) == 0:
|
||||
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found")
|
||||
|
||||
# shortcut: look for the quantization in the name
|
||||
if any(x for x in files if "llm_int8" in x.as_posix()):
|
||||
return ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
# more reliable path: probe contents for a 'SCB' key
|
||||
ckpt = read_checkpoint_meta(files[0], scan=True)
|
||||
if any("SCB" in x for x in ckpt.keys()):
|
||||
return ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format")
|
||||
return ModelFormat.T5Encoder
|
||||
|
||||
|
||||
class ONNXFolderProbe(PipelineFolderProbe):
|
||||
|
||||
@@ -133,29 +133,3 @@ def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[in
|
||||
break
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
|
||||
def convert_bundle_to_flux_transformer_checkpoint(
|
||||
transformer_state_dict: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
original_state_dict: dict[str, torch.Tensor] = {}
|
||||
keys_to_remove: list[str] = []
|
||||
|
||||
for k, v in transformer_state_dict.items():
|
||||
if not k.startswith("model.diffusion_model"):
|
||||
keys_to_remove.append(k) # This can be removed in the future if we only want to delete transformer keys
|
||||
continue
|
||||
if k.endswith("scale"):
|
||||
# Scale math must be done at bfloat16 due to our current flux model
|
||||
# support limitations at inference time
|
||||
v = v.to(dtype=torch.bfloat16)
|
||||
new_key = k.replace("model.diffusion_model.", "")
|
||||
original_state_dict[new_key] = v
|
||||
keys_to_remove.append(k)
|
||||
|
||||
# Remove processed keys from the original dictionary, leaving others in case
|
||||
# other model state dicts need to be pulled
|
||||
for k in keys_to_remove:
|
||||
del transformer_state_dict[k]
|
||||
|
||||
return original_state_dict
|
||||
|
||||
@@ -5,18 +5,32 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
||||
(lora_model2, 0.4),
|
||||
]
|
||||
with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
# unet with applied loras
|
||||
# unmodified unet
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
@@ -40,6 +54,95 @@ class ModelPatcher:
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(
|
||||
unet,
|
||||
loras=loras,
|
||||
prefix="lora_unet_",
|
||||
cached_weights=cached_weights,
|
||||
):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
|
||||
:param model: The model to patch.
|
||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||
"""
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for lora_model, lora_weight in loras:
|
||||
LoRAExt.patch_model(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
lora=lora_model,
|
||||
lora_weight=lora_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
@@ -179,6 +282,26 @@ class ModelPatcher:
|
||||
|
||||
|
||||
class ONNXModelPatcher:
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: OnnxRuntimeModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
# based on
|
||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||
@classmethod
|
||||
|
||||
@@ -54,10 +54,8 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
|
||||
# Currently, we only support weight_format=0.
|
||||
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
assert weight_format == 0
|
||||
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
|
||||
_weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
@@ -91,14 +89,6 @@ class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
# Reset the state. The persisted fields are based on the initialization behaviour in
|
||||
# `bnb.nn.Linear8bitLt.__init__()`.
|
||||
new_state = bnb.MatmulLtState()
|
||||
new_state.threshold = self.state.threshold
|
||||
new_state.has_fp16_weights = False
|
||||
new_state.use_pool = self.state.use_pool
|
||||
self.state = new_state
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
|
||||
@@ -43,11 +43,6 @@ class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoraPatcher
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
@@ -30,14 +31,107 @@ class LoRAExt(ExtensionBase):
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
LoraPatcher.apply_lora_patch(
|
||||
self.patch_model(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
patch=lora_model,
|
||||
patch_weight=self._weight,
|
||||
lora=lora_model,
|
||||
lora_weight=self._weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
del lora_model
|
||||
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def patch_model(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
lora: LoRAModelRaw,
|
||||
lora_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
"""
|
||||
Apply one or more LoRAs to a model.
|
||||
:param model: The model to patch.
|
||||
:param lora: LoRA model to patch in.
|
||||
:param lora_weight: LoRA patch weight.
|
||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
||||
"""
|
||||
|
||||
if lora_weight == 0:
|
||||
return
|
||||
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
|
||||
# save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
|
||||
lora_param_weight *= lora_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@@ -3,9 +3,10 @@ Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import Chdir, directory_size
|
||||
from invokeai.backend.util.util import GIG, Chdir, directory_size
|
||||
|
||||
__all__ = [
|
||||
"GIG",
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
|
||||
@@ -7,6 +7,9 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
|
||||
@@ -16,14 +16,6 @@ module.exports = {
|
||||
'no-promise-executor-return': 'error',
|
||||
// https://eslint.org/docs/latest/rules/require-await
|
||||
'require-await': 'error',
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
object: 'crypto',
|
||||
property: 'randomUUID',
|
||||
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
|
||||
},
|
||||
],
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
|
||||
@@ -11,8 +11,6 @@ const config: KnipConfig = {
|
||||
'src/features/nodes/types/v2/**',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// TODO(psyche): restore HRF functionality?
|
||||
'src/features/hrf/**',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
paths: {
|
||||
|
||||
@@ -58,7 +58,7 @@
|
||||
"@dnd-kit/sortable": "^8.0.0",
|
||||
"@dnd-kit/utilities": "^3.2.2",
|
||||
"@fontsource-variable/inter": "^5.0.20",
|
||||
"@invoke-ai/ui-library": "^0.0.33",
|
||||
"@invoke-ai/ui-library": "^0.0.32",
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@reduxjs/toolkit": "2.2.3",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
@@ -92,7 +92,7 @@
|
||||
"react-i18next": "^14.1.3",
|
||||
"react-icons": "^5.2.1",
|
||||
"react-redux": "9.1.2",
|
||||
"react-resizable-panels": "^2.1.2",
|
||||
"react-resizable-panels": "^2.0.23",
|
||||
"react-use": "^17.5.1",
|
||||
"react-virtuoso": "^4.9.0",
|
||||
"reactflow": "^11.11.4",
|
||||
@@ -136,7 +136,6 @@
|
||||
"@vitest/coverage-v8": "^1.5.0",
|
||||
"@vitest/ui": "^1.5.0",
|
||||
"concurrently": "^8.2.2",
|
||||
"csstype": "^3.1.3",
|
||||
"dpdm": "^3.14.0",
|
||||
"eslint": "^8.57.0",
|
||||
"eslint-plugin-i18next": "^6.0.9",
|
||||
|
||||
invokeai/frontend/web/pnpm-lock.yaml (generated, 65 changed lines)
@@ -24,8 +24,8 @@ dependencies:
|
||||
specifier: ^5.0.20
|
||||
version: 5.0.20
|
||||
'@invoke-ai/ui-library':
|
||||
specifier: ^0.0.33
|
||||
version: 0.0.33(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^0.0.32
|
||||
version: 0.0.32(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@nanostores/react':
|
||||
specifier: ^0.7.3
|
||||
version: 0.7.3(nanostores@0.11.2)(react@18.3.1)
|
||||
@@ -126,8 +126,8 @@ dependencies:
|
||||
specifier: 9.1.2
|
||||
version: 9.1.2(@types/react@18.3.3)(react@18.3.1)(redux@5.0.1)
|
||||
react-resizable-panels:
|
||||
specifier: ^2.1.2
|
||||
version: 2.1.2(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^2.0.23
|
||||
version: 2.0.23(react-dom@18.3.1)(react@18.3.1)
|
||||
react-use:
|
||||
specifier: ^17.5.1
|
||||
version: 17.5.1(react-dom@18.3.1)(react@18.3.1)
|
||||
@@ -238,9 +238,6 @@ devDependencies:
|
||||
concurrently:
|
||||
specifier: ^8.2.2
|
||||
version: 8.2.2
|
||||
csstype:
|
||||
specifier: ^3.1.3
|
||||
version: 3.1.3
|
||||
dpdm:
|
||||
specifier: ^3.14.0
|
||||
version: 3.14.0
|
||||
@@ -2056,7 +2053,7 @@ packages:
|
||||
dependencies:
|
||||
'@chakra-ui/dom-utils': 2.1.0
|
||||
react: 18.3.1
|
||||
react-focus-lock: 2.13.2(@types/react@18.3.3)(react@18.3.1)
|
||||
react-focus-lock: 2.13.0(@types/react@18.3.3)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
@@ -2253,7 +2250,7 @@ packages:
|
||||
framer-motion: 10.18.0(react-dom@18.3.1)(react@18.3.1)
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
react-remove-scroll: 2.6.0(@types/react@18.3.3)(react@18.3.1)
|
||||
react-remove-scroll: 2.5.10(@types/react@18.3.3)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
@@ -3574,8 +3571,8 @@ packages:
|
||||
prettier: 3.3.3
|
||||
dev: true
|
||||
|
||||
/@invoke-ai/ui-library@0.0.33(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-YLydTCOTUEgju4Ex6yXt/bvNBcO97y6zc1cGYjt7vtJMS8e6deA89cC5JejjbmVgntdnn49cDyeUxB8Z24gZew==}
|
||||
/@invoke-ai/ui-library@0.0.32(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@fontsource-variable/inter@5.0.20)(@types/react@18.3.3)(i18next@23.12.2)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-JxAoblrDu/cZ4ha9KO4ry5OWvyLUE1Dj28i+ciMaDNUpC/cN+IyiTbUBoFoPaoN5JP8Zpd/MYCcmF2qsziHDzg==}
|
||||
peerDependencies:
|
||||
'@fontsource-variable/inter': ^5.0.16
|
||||
react: ^18.2.0
|
||||
@@ -10011,8 +10008,8 @@ packages:
|
||||
resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==}
|
||||
dev: false
|
||||
|
||||
/react-focus-lock@2.13.2(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-T/7bsofxYqnod2xadvuwjGKHOoL5GH7/EIPI5UyEvaU/c2CcphvGI371opFtuY/SYdbMsNiuF4HsHQ50nA/TKQ==}
|
||||
/react-focus-lock@2.13.0(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-w7aIcTwZwNzUp2fYQDMICy+khFwVmKmOrLF8kNsPS+dz4Oi/oxoVJ2wCMVvX6rWGriM/+mYaTyp1MRmkcs2amw==}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
@@ -10147,6 +10144,25 @@ packages:
|
||||
tslib: 2.7.0
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.5.10(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-m3zvBRANPBw3qxVVjEIPEQinkcwlFZ4qyomuWVpNJdv4c6MvHfXV0C3L9Jx5rr3HeBHKNRX+1jreB5QloDIJjA==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
react-remove-scroll-bar: 2.3.6(@types/react@18.3.3)(react@18.3.1)
|
||||
react-style-singleton: 2.2.1(@types/react@18.3.3)(react@18.3.1)
|
||||
tslib: 2.7.0
|
||||
use-callback-ref: 1.3.2(@types/react@18.3.3)(react@18.3.1)
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.5.5(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==}
|
||||
engines: {node: '>=10'}
|
||||
@@ -10166,27 +10182,8 @@ packages:
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.6.0(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-I2U4JVEsQenxDAKaVa3VZ/JeJZe0/2DxPWL8Tj8yLKctQJQiZM52pn/GWFpSp8dftjM3pSAHVJZscAnC/y+ySQ==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
react-remove-scroll-bar: 2.3.6(@types/react@18.3.3)(react@18.3.1)
|
||||
react-style-singleton: 2.2.1(@types/react@18.3.3)(react@18.3.1)
|
||||
tslib: 2.7.0
|
||||
use-callback-ref: 1.3.2(@types/react@18.3.3)(react@18.3.1)
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-resizable-panels@2.1.2(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-Ku2Bo7JvE8RpHhl4X1uhkdeT9auPBoxAOlGTqomDUUrBAX2mVGuHYZTcWvlnJSgx0QyHIxHECgGB5XVPUbUOkQ==}
|
||||
/react-resizable-panels@2.0.23(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-8ZKTwTU11t/FYwiwhMdtZYYyFxic5U5ysRu2YwfkAgDbUJXFvnWSJqhnzkSlW+mnDoNAzDCrJhdOSXBPA76wug==}
|
||||
peerDependencies:
|
||||
react: ^16.14.0 || ^17.0.0 || ^18.0.0
|
||||
react-dom: ^16.14.0 || ^17.0.0 || ^18.0.0
|
||||
|
||||
|
(binary image file changed; size 1.7 KiB before and after)
@@ -127,14 +127,7 @@
|
||||
"bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
|
||||
"bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
|
||||
"bulkDownloadFailed": "Download fehlgeschlagen",
|
||||
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen",
|
||||
"selectForCompare": "Zum Vergleichen auswählen",
|
||||
"compareImage": "Bilder vergleichen",
|
||||
"exitSearch": "Suche beenden",
|
||||
"newestFirst": "Neueste zuerst",
|
||||
"oldestFirst": "Älteste zuerst",
|
||||
"openInViewer": "Im Viewer öffnen",
|
||||
"swapImages": "Bilder tauschen"
|
||||
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tastenkürzel",
|
||||
@@ -638,8 +631,7 @@
|
||||
"archived": "Archiviert",
|
||||
"noBoards": "Kein {boardType}} Ordner",
|
||||
"hideBoards": "Ordner verstecken",
|
||||
"viewBoards": "Ordner ansehen",
|
||||
"deletedPrivateBoardsCannotbeRestored": "Gelöschte Boards können nicht wiederhergestellt werden. Wenn Sie „Nur Board löschen“ wählen, werden die Bilder in einen privaten, nicht kategorisierten Status für den Ersteller des Bildes versetzt."
|
||||
"viewBoards": "Ordner ansehen"
|
||||
},
|
||||
"controlnet": {
|
||||
"showAdvanced": "Zeige Erweitert",
|
||||
@@ -789,9 +781,7 @@
|
||||
"batchFieldValues": "Stapelverarbeitungswerte",
|
||||
"batchQueued": "Stapelverarbeitung eingereiht",
|
||||
"graphQueued": "Graph eingereiht",
|
||||
"graphFailedToQueue": "Fehler beim Einreihen des Graphen",
|
||||
"generations_one": "Generation",
|
||||
"generations_other": "Generationen"
|
||||
"graphFailedToQueue": "Fehler beim Einreihen des Graphen"
|
||||
},
|
||||
"metadata": {
|
||||
"negativePrompt": "Negativ Beschreibung",
|
||||
@@ -1156,10 +1146,5 @@
|
||||
"noMatchingTriggers": "Keine passenden Trigger",
|
||||
"addPromptTrigger": "Prompt-Trigger hinzufügen",
|
||||
"compatibleEmbeddings": "Kompatible Einbettungen"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"queue": "Warteschlange"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,6 @@
|
||||
"copy": "Copy",
|
||||
"copyError": "$t(gallery.copy) Error",
|
||||
"on": "On",
|
||||
"off": "Off",
|
||||
"or": "or",
|
||||
"checkpoint": "Checkpoint",
|
||||
"communityLabel": "Community",
|
||||
@@ -135,7 +134,6 @@
|
||||
"nodes": "Workflows",
|
||||
"notInstalled": "Not $t(common.installed)",
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"openInViewer": "Open in Viewer",
|
||||
"orderBy": "Order By",
|
||||
"outpaint": "outpaint",
|
||||
"outputs": "Outputs",
|
||||
@@ -375,7 +373,6 @@
|
||||
"useCache": "Use Cache"
|
||||
},
|
||||
"gallery": {
|
||||
"gallery": "Gallery",
|
||||
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
|
||||
"assets": "Assets",
|
||||
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
|
||||
@@ -388,11 +385,11 @@
|
||||
"deleteImage_one": "Delete Image",
|
||||
"deleteImage_other": "Delete {{count}} Images",
|
||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||
"displayBoardSearch": "Board Search",
|
||||
"displaySearch": "Image Search",
|
||||
"displayBoardSearch": "Display Board Search",
|
||||
"displaySearch": "Display Search",
|
||||
"download": "Download",
|
||||
"exitBoardSearch": "Exit Board Search",
|
||||
"exitSearch": "Exit Image Search",
|
||||
"exitSearch": "Exit Search",
|
||||
"featuresWillReset": "If you delete this image, those features will immediately be reset.",
|
||||
"galleryImageSize": "Image Size",
|
||||
"gallerySettings": "Gallery Settings",
|
||||
@@ -438,8 +435,7 @@
|
||||
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
|
||||
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
|
||||
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
|
||||
"toggleMiniViewer": "Toggle Mini Viewer"
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Search Hotkeys",
|
||||
@@ -710,8 +706,6 @@
|
||||
"availableModels": "Available Models",
|
||||
"baseModel": "Base Model",
|
||||
"cancel": "Cancel",
|
||||
"clipEmbed": "CLIP Embed",
|
||||
"clipVision": "CLIP Vision",
|
||||
"config": "Config",
|
||||
"convert": "Convert",
|
||||
"convertingModelBegin": "Converting Model. Please wait.",
|
||||
@@ -799,7 +793,6 @@
|
||||
"settings": "Settings",
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"syncModels": "Sync Models",
|
||||
@@ -808,7 +801,6 @@
|
||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"uploadImage": "Upload Image",
|
||||
"urlOrLocalPath": "URL or Local Path",
|
||||
@@ -1018,8 +1010,6 @@
|
||||
"noModelForControlAdapter": "Control Adapter #{{number}} has no model selected.",
|
||||
"incompatibleBaseModelForControlAdapter": "Control Adapter #{{number}} model is incompatible with main model.",
|
||||
"noModelSelected": "No model selected",
|
||||
"canvasManagerNotLoaded": "Canvas Manager not loaded",
|
||||
"canvasBusy": "Canvas is busy",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"systemDisconnected": "System disconnected",
|
||||
@@ -1051,11 +1041,12 @@
|
||||
"scaledHeight": "Scaled H",
|
||||
"scaledWidth": "Scaled W",
|
||||
"scheduler": "Scheduler",
|
||||
"seamlessXAxis": "Seamless X Axis",
|
||||
"seamlessYAxis": "Seamless Y Axis",
|
||||
"seamlessXAxis": "Seamless Tiling X Axis",
|
||||
"seamlessYAxis": "Seamless Tiling Y Axis",
|
||||
"seed": "Seed",
|
||||
"imageActions": "Image Actions",
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
"sendToImg2Img": "Send to Image to Image",
|
||||
"sendToUnifiedCanvas": "Send To Unified Canvas",
|
||||
"sendToUpscale": "Send To Upscale",
|
||||
"showOptionsPanel": "Show Side Panel (O or T)",
|
||||
"shuffle": "Shuffle Seed",
|
||||
@@ -1196,8 +1187,8 @@
|
||||
"problemSavingMaskDesc": "Unable to export mask",
|
||||
"prunedQueue": "Pruned Queue",
|
||||
"resetInitialImage": "Reset Initial Image",
|
||||
"sentToCanvas": "Sent to Canvas",
|
||||
"sentToUpscale": "Sent to Upscale",
|
||||
"sentToImageToImage": "Sent To Image To Image",
|
||||
"sentToUnifiedCanvas": "Sent to Unified Canvas",
|
||||
"serverError": "Server Error",
|
||||
"sessionRef": "Session: {{sessionId}}",
|
||||
"setAsCanvasInitialImage": "Set as canvas initial image",
|
||||
@@ -1659,16 +1650,6 @@
|
||||
"storeNotInitialized": "Store is not initialized"
|
||||
},
|
||||
"controlLayers": {
|
||||
"bookmark": "Bookmark for Quick Switch",
|
||||
"fitBboxToLayers": "Fit Bbox To Layers",
|
||||
"removeBookmark": "Remove Bookmark",
|
||||
"saveCanvasToGallery": "Save Canvas To Gallery",
|
||||
"saveBboxToGallery": "Save Bbox To Gallery",
|
||||
"savedToGalleryOk": "Saved to Gallery",
|
||||
"savedToGalleryError": "Error saving to gallery",
|
||||
"mergeVisible": "Merge Visible",
|
||||
"mergeVisibleOk": "Merged visible layers",
|
||||
"mergeVisibleError": "Error merging visible layers",
|
||||
"clearHistory": "Clear History",
|
||||
"generateMode": "Generate",
|
||||
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
|
||||
@@ -1680,7 +1661,6 @@
|
||||
"clearCaches": "Clear Caches",
|
||||
"recalculateRects": "Recalculate Rects",
|
||||
"clipToBbox": "Clip Strokes to Bbox",
|
||||
"compositeMaskedRegions": "Composite Masked Regions",
|
||||
"addLayer": "Add Layer",
|
||||
"duplicate": "Duplicate",
|
||||
"moveToFront": "Move to Front",
|
||||
@@ -1716,8 +1696,6 @@
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"sendingToCanvas": "Sending to Canvas",
|
||||
"sendingToGallery": "Sending to Gallery",
|
||||
"sendToGallery": "Send To Gallery",
|
||||
"sendToGalleryDesc": "Generations will be sent to the gallery.",
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
@@ -1736,12 +1714,12 @@
|
||||
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
|
||||
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
|
||||
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
|
||||
"globalIPAdapters_withCount_hidden": "Global IP Adapters ({{count}} hidden)",
|
||||
"ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
|
||||
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
|
||||
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
|
||||
"controlLayers_withCount_visible": "Control Layers ({{count}})",
|
||||
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
|
||||
"globalIPAdapters_withCount_visible": "Global IP Adapters ({{count}})",
|
||||
"ipAdapters_withCount_visible": "IP Adapters ({{count}})",
|
||||
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
|
||||
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
|
||||
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
||||
@@ -1754,8 +1732,8 @@
|
||||
"clearProcessor": "Clear Processor",
|
||||
"resetProcessor": "Reset Processor to Defaults",
|
||||
"noLayersAdded": "No Layers Added",
|
||||
"layer_one": "Layer",
|
||||
"layer_other": "Layers",
|
||||
"layers_one": "Layer",
|
||||
"layers_other": "Layers",
|
||||
"objects_zero": "empty",
|
||||
"objects_one": "{{count}} object",
|
||||
"objects_other": "{{count}} objects",
|
||||
@@ -1791,54 +1769,16 @@
|
||||
"bbox": "Bbox",
|
||||
"move": "Move",
|
||||
"view": "View",
|
||||
"transform": "Transform",
|
||||
"colorPicker": "Color Picker"
|
||||
},
|
||||
"filter": {
|
||||
"filter": "Filter",
|
||||
"filters": "Filters",
|
||||
"filterType": "Filter Type",
|
||||
"autoProcess": "Auto Process",
|
||||
"reset": "Reset",
|
||||
"process": "Process",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel",
|
||||
"spandrel": {
|
||||
"label": "Image-to-Image Model",
|
||||
"description": "Run an image-to-image model on the selected layer.",
|
||||
"paramModel": "Model",
|
||||
"paramAutoScale": "Auto Scale",
|
||||
"paramAutoScaleDesc": "The selected model will be run until the target scale is reached.",
|
||||
"paramScale": "Target Scale"
|
||||
}
|
||||
},
|
||||
"transform": {
|
||||
"transform": "Transform",
|
||||
"fitToBbox": "Fit to Bbox",
|
||||
"reset": "Reset",
|
||||
"preview": "Preview",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"settings": {
|
||||
"snapToGrid": {
|
||||
"label": "Snap to Grid",
|
||||
"on": "On",
|
||||
"off": "Off"
|
||||
}
|
||||
},
|
||||
"HUD": {
|
||||
"bbox": "Bbox",
|
||||
"scaledBbox": "Scaled Bbox",
|
||||
"autoSave": "Auto Save",
|
||||
"entityStatus": {
|
||||
"selectedEntity": "Selected Entity",
|
||||
"selectedEntityIs": "Selected Entity is",
|
||||
"isFiltering": "is filtering",
|
||||
"isTransforming": "is transforming",
|
||||
"isLocked": "is locked",
|
||||
"isHidden": "is hidden",
|
||||
"isDisabled": "is disabled",
|
||||
"enabled": "Enabled"
|
||||
}
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
|
||||
@@ -86,15 +86,15 @@
|
||||
"loadMore": "Cargar más",
|
||||
"noImagesInGallery": "No hay imágenes para mostrar",
|
||||
"deleteImage_one": "Eliminar Imagen",
|
||||
"deleteImage_many": "Eliminar {{count}} Imágenes",
|
||||
"deleteImage_other": "Eliminar {{count}} Imágenes",
|
||||
"deleteImage_many": "",
|
||||
"deleteImage_other": "",
|
||||
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
|
||||
"assets": "Activos",
|
||||
"autoAssignBoardOnClick": "Asignación automática de tableros al hacer clic"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Atajos de teclado",
|
||||
"appHotkeys": "Atajos de aplicación",
|
||||
"appHotkeys": "Atajos de applicación",
|
||||
"generalHotkeys": "Atajos generales",
|
||||
"galleryHotkeys": "Atajos de galería",
|
||||
"unifiedCanvasHotkeys": "Atajos de lienzo unificado",
|
||||
@@ -535,7 +535,7 @@
|
||||
"bottomMessage": "Al eliminar este panel y las imágenes que contiene, se restablecerán las funciones que los estén utilizando actualmente.",
|
||||
"deleteBoardAndImages": "Borrar el panel y las imágenes",
|
||||
"loading": "Cargando...",
|
||||
"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar. Al Seleccionar 'Borrar Solo el Panel' transferirá las imágenes a un estado sin categorizar.",
|
||||
"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar",
|
||||
"move": "Mover",
|
||||
"menuItemAutoAdd": "Agregar automáticamente a este panel",
|
||||
"searchBoard": "Buscando paneles…",
|
||||
@@ -549,13 +549,7 @@
|
||||
"imagesWithCount_other": "{{count}} imágenes",
|
||||
"assetsWithCount_one": "{{count}} activo",
|
||||
"assetsWithCount_many": "{{count}} activos",
|
||||
"assetsWithCount_other": "{{count}} activos",
|
||||
"hideBoards": "Ocultar Paneles",
|
||||
"addPrivateBoard": "Agregar un tablero privado",
|
||||
"addSharedBoard": "Agregar Panel Compartido",
|
||||
"boards": "Paneles",
|
||||
"archiveBoard": "Archivar Panel",
|
||||
"archived": "Archivado"
|
||||
"assetsWithCount_other": "{{count}} activos"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
|
||||
@@ -496,9 +496,7 @@
|
||||
"main": "Principali",
|
||||
"noModelsInstalledDesc1": "Installa i modelli con",
|
||||
"ipAdapters": "Adattatori IP",
|
||||
"noMatchingModels": "Nessun modello corrispondente",
|
||||
"starterModelsInModelManager": "I modelli iniziali possono essere trovati in Gestione Modelli",
|
||||
"spandrelImageToImage": "Immagine a immagine (Spandrel)"
|
||||
"noMatchingModels": "Nessun modello corrispondente"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -512,7 +510,7 @@
|
||||
"perlinNoise": "Rumore Perlin",
|
||||
"type": "Tipo",
|
||||
"strength": "Forza",
|
||||
"upscaling": "Amplia",
|
||||
"upscaling": "Ampliamento",
|
||||
"scale": "Scala",
|
||||
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
|
||||
"scaleBeforeProcessing": "Scala prima dell'elaborazione",
|
||||
@@ -595,7 +593,7 @@
|
||||
"globalPositivePromptPlaceholder": "Prompt positivo globale",
|
||||
"globalNegativePromptPlaceholder": "Prompt negativo globale",
|
||||
"processImage": "Elabora Immagine",
|
||||
"sendToUpscale": "Invia a Amplia",
|
||||
"sendToUpscale": "Invia a Ampliare",
|
||||
"postProcessing": "Post-elaborazione (Shift + U)"
|
||||
},
|
||||
"settings": {
|
||||
@@ -1422,7 +1420,7 @@
|
||||
"paramUpscaleMethod": {
|
||||
"heading": "Metodo di ampliamento",
|
||||
"paragraphs": [
|
||||
"Metodo utilizzato per ampliare l'immagine per la correzione ad alta risoluzione."
|
||||
"Metodo utilizzato per eseguire l'ampliamento dell'immagine per la correzione ad alta risoluzione."
|
||||
]
|
||||
},
|
||||
"patchmatchDownScaleSize": {
|
||||
@@ -1530,7 +1528,7 @@
|
||||
},
|
||||
"upscaleModel": {
|
||||
"paragraphs": [
|
||||
"Il modello di ampliamento, scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
|
||||
"Il modello di ampliamento (Upscale), scala l'immagine alle dimensioni di uscita prima di aggiungere i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
|
||||
],
|
||||
"heading": "Modello di ampliamento"
|
||||
},
|
||||
@@ -1722,27 +1720,26 @@
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Coda",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"upscaling": "Amplia",
|
||||
"upscaling": "Ampliamento",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
"creativity": "Creatività",
|
||||
"structure": "Struttura",
|
||||
"upscaleModel": "Modello di ampliamento",
|
||||
"upscaleModel": "Modello di Ampliamento",
|
||||
"scale": "Scala",
|
||||
"missingModelsWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli richiesti:",
|
||||
"mainModelDesc": "Modello principale (architettura SD1.5 o SDXL)",
|
||||
"tileControlNetModelDesc": "Modello Tile ControlNet per l'architettura del modello principale scelto",
|
||||
"upscaleModelDesc": "Modello per l'ampliamento (immagine a immagine)",
|
||||
"upscaleModelDesc": "Modello per l'ampliamento (da immagine a immagine)",
|
||||
"missingUpscaleInitialImage": "Immagine iniziale mancante per l'ampliamento",
|
||||
"missingUpscaleModel": "Modello per l’ampliamento mancante",
|
||||
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
|
||||
"postProcessingModel": "Modello di post-elaborazione",
|
||||
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine).",
|
||||
"exceedsMaxSize": "Le impostazioni di ampliamento superano il limite massimo delle dimensioni",
|
||||
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata.",
|
||||
"upscale": "Amplia"
|
||||
"exceedsMaxSizeDetails": "Il limite massimo di ampliamento è {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixel. Prova un'immagine più piccola o diminuisci la scala selezionata."
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invita collaboratori",
|
||||
@@ -1792,7 +1789,6 @@
|
||||
"positivePromptColumn": "'prompt' o 'positive_prompt'",
|
||||
"noTemplates": "Nessun modello",
|
||||
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
|
||||
"templateActions": "Azioni modello",
|
||||
"promptTemplateCleared": "Modello di prompt cancellato"
|
||||
"templateActions": "Azioni modello"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -501,8 +501,7 @@
|
||||
"noModelsInstalled": "Нет установленных моделей",
|
||||
"noModelsInstalledDesc1": "Установите модели с помощью",
|
||||
"noMatchingModels": "Нет подходящих моделей",
|
||||
"ipAdapters": "IP адаптеры",
|
||||
"starterModelsInModelManager": "Стартовые модели можно найти в Менеджере моделей"
|
||||
"ipAdapters": "IP адаптеры"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Изображения",
|
||||
@@ -1759,8 +1758,7 @@
|
||||
"postProcessingModel": "Модель постобработки",
|
||||
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
|
||||
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
|
||||
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img).",
|
||||
"upscale": "Увеличить"
|
||||
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
|
||||
},
|
||||
"stylePresets": {
|
||||
"noMatchingTemplates": "Нет подходящих шаблонов",
|
||||
@@ -1806,8 +1804,7 @@
|
||||
"noTemplates": "Нет шаблонов",
|
||||
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
|
||||
"searchByName": "Поиск по имени",
|
||||
"shared": "Общий",
|
||||
"promptTemplateCleared": "Шаблон запроса создан"
|
||||
"shared": "Общий"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Пригласите членов команды",
|
||||
|
||||
@@ -154,8 +154,7 @@
|
||||
"displaySearch": "显示搜索",
|
||||
"stretchToFit": "拉伸以适应",
|
||||
"exitCompare": "退出对比",
|
||||
"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。",
|
||||
"go": "运行"
|
||||
"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "快捷键",
|
||||
@@ -495,9 +494,7 @@
|
||||
"huggingFacePlaceholder": "所有者或模型名称",
|
||||
"huggingFaceRepoID": "HuggingFace仓库ID",
|
||||
"loraTriggerPhrases": "LoRA 触发词",
|
||||
"ipAdapters": "IP适配器",
|
||||
"spandrelImageToImage": "图生图(Spandrel)",
|
||||
"starterModelsInModelManager": "您可以在模型管理器中找到初始模型"
|
||||
"ipAdapters": "IP适配器"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "图像",
|
||||
@@ -698,9 +695,7 @@
|
||||
"outOfMemoryErrorDesc": "您当前的生成设置已超出系统处理能力.请调整设置后再次尝试.",
|
||||
"parametersSet": "参数已恢复",
|
||||
"errorCopied": "错误信息已复制",
|
||||
"modelImportCanceled": "模型导入已取消",
|
||||
"importFailed": "导入失败",
|
||||
"importSuccessful": "导入成功"
|
||||
"modelImportCanceled": "模型导入已取消"
|
||||
},
|
||||
"unifiedCanvas": {
|
||||
"layer": "图层",
|
||||
@@ -1710,55 +1705,12 @@
|
||||
"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
|
||||
"mainModelDesc": "主模型(SD1.5或SDXL架构)",
|
||||
"exceedsMaxSize": "放大设置超出了最大尺寸限制",
|
||||
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择.",
|
||||
"upscale": "放大"
|
||||
"exceedsMaxSizeDetails": "最大放大限制是 {{maxUpscaleDimension}}x{{maxUpscaleDimension}} 像素.请尝试一个较小的图像或减少您的缩放选择."
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "邀请团队成员",
|
||||
"professional": "专业",
|
||||
"professionalUpsell": "可在 Invoke 的专业版中使用.点击此处或访问 invoke.com/pricing 了解更多详情.",
|
||||
"shareAccess": "共享访问权限"
|
||||
},
|
||||
"stylePresets": {
|
||||
"positivePrompt": "正向提示词",
|
||||
"preview": "预览",
|
||||
"deleteImage": "删除图像",
|
||||
"deleteTemplate": "删除模版",
|
||||
"deleteTemplate2": "您确定要删除这个模板吗?请注意,删除后无法恢复.",
|
||||
"importTemplates": "导入提示模板,支持CSV或JSON格式",
|
||||
"insertPlaceholder": "插入一个占位符",
|
||||
"myTemplates": "我的模版",
|
||||
"name": "名称",
|
||||
"type": "类型",
|
||||
"unableToDeleteTemplate": "无法删除提示模板",
|
||||
"updatePromptTemplate": "更新提示词模版",
|
||||
"exportPromptTemplates": "导出我的提示模板为CSV格式",
|
||||
"exportDownloaded": "导出已下载",
|
||||
"noMatchingTemplates": "无匹配的模版",
|
||||
"promptTemplatesDesc1": "提示模板可以帮助您在编写提示时添加预设的文本内容.",
|
||||
"promptTemplatesDesc3": "如果您没有使用占位符,那么模板的内容将会被添加到您提示的末尾.",
|
||||
"searchByName": "按名称搜索",
|
||||
"shared": "已分享",
|
||||
"sharedTemplates": "已分享的模版",
|
||||
"templateActions": "模版操作",
|
||||
"templateDeleted": "提示模版已删除",
|
||||
"toggleViewMode": "切换显示模式",
|
||||
"uploadImage": "上传图像",
|
||||
"active": "激活",
|
||||
"choosePromptTemplate": "选择提示词模板",
|
||||
"clearTemplateSelection": "清除模版选择",
|
||||
"copyTemplate": "拷贝模版",
|
||||
"createPromptTemplate": "创建提示词模版",
|
||||
"defaultTemplates": "默认模版",
|
||||
"editTemplate": "编辑模版",
|
||||
"exportFailed": "无法生成并下载CSV文件",
|
||||
"flatten": "将选定的模板内容合并到当前提示中",
|
||||
"negativePrompt": "反向提示词",
|
||||
"promptTemplateCleared": "提示模板已清除",
|
||||
"useForTemplate": "用于提示词模版",
|
||||
"viewList": "预览模版列表",
|
||||
"viewModeTooltip": "这是您的提示在当前选定的模板下的预览效果。如需编辑提示,请直接在文本框中点击进行修改.",
|
||||
"noTemplates": "无模版",
|
||||
"private": "私密"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
|
||||
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
@@ -44,17 +43,10 @@ interface Props {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: TabName;
|
||||
destination?: TabName | undefined;
|
||||
}
|
||||
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
}: Props) => {
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
|
||||
const language = useAppSelector(selectLanguage);
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -93,12 +85,6 @@ const App = ({
|
||||
}
|
||||
}, [selectedWorkflowId, getAndLoadWorkflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedStylePresetId) {
|
||||
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
|
||||
}
|
||||
}, [dispatch, selectedStylePresetId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (destination) {
|
||||
dispatch(setActiveTab(destination));
|
||||
|
||||
@@ -45,7 +45,6 @@ interface Props extends PropsWithChildren {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: TabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
@@ -67,7 +66,6 @@ const InvokeAIUI = ({
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@@ -229,7 +227,6 @@ const InvokeAIUI = ({
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
selectedStylePresetId={selectedStylePresetId}
|
||||
destination={destination}
|
||||
/>
|
||||
</AppDndContext>
|
||||
|
||||
@@ -21,16 +21,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
direction,
|
||||
shadows: {
|
||||
..._theme.shadows,
|
||||
selected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverSelected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverUnselected:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
selectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
|
||||
hoverSelectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
|
||||
},
|
||||
});
|
||||
}, [direction]);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { TypedStartListening } from '@reduxjs/toolkit';
|
||||
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
@@ -9,7 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addCancellationsListeners } from 'app/store/middleware/listenerMiddleware/listeners/cancellationsListeners';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
@@ -41,8 +40,6 @@ export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
const startAppListening = listenerMiddleware.startListening as AppStartListening;
|
||||
|
||||
export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||
*
|
||||
@@ -122,5 +119,3 @@ addDynamicPromptsListener(startAppListening);
|
||||
|
||||
addSetDefaultSettingsListener(startAppListening);
|
||||
// addControlAdapterPreprocessor(startAppListening);
|
||||
|
||||
addCancellationsListeners(startAppListening);
|
||||
|
||||
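The doc comment in the hunk above calls the RTK listener middleware a lightweight alternative to sagas/observables. As a rough, self-contained sketch of the pattern each add*Listener(startAppListening) call wires up — the gallery slice and action below are illustrative assumptions, not part of this diff:

import type { PayloadAction, TypedStartListening } from '@reduxjs/toolkit';
import { configureStore, createListenerMiddleware, createSlice } from '@reduxjs/toolkit';

// Hypothetical slice standing in for one of the app's real slices.
const gallerySlice = createSlice({
  name: 'gallery',
  initialState: { selection: [] as string[] },
  reducers: {
    imageSelected: (state, action: PayloadAction<string>) => {
      state.selection.push(action.payload);
    },
  },
});

const listenerMiddleware = createListenerMiddleware();

const store = configureStore({
  reducer: { gallery: gallerySlice.reducer },
  middleware: (getDefaultMiddleware) => getDefaultMiddleware().prepend(listenerMiddleware.middleware),
});

type RootState = ReturnType<typeof store.getState>;
type AppDispatch = typeof store.dispatch;
type AppStartListening = TypedStartListening<RootState, AppDispatch>;

const startAppListening = listenerMiddleware.startListening as AppStartListening;

// A listener pairs a matcher (here an action creator) with an effect that can
// read the latest state and run side effects after the reducer has handled the action.
startAppListening({
  actionCreator: gallerySlice.actions.imageSelected,
  effect: (action, { getState }) => {
    console.log(`selected ${action.payload}; total`, getState().gallery.selection.length);
  },
});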
@@ -1,34 +1,37 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { canvasReset, rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { stagingAreaImageAccepted, stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
sessionStagingAreaImageAccepted,
|
||||
sessionStagingAreaReset,
|
||||
} from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('canvas');
|
||||
|
||||
const matchCanvasOrStagingAreaRest = isAnyOf(stagingAreaReset, canvasReset);
|
||||
|
||||
export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: matchCanvasOrStagingAreaRest,
|
||||
actionCreator: sessionStagingAreaReset,
|
||||
effect: async (_, { dispatch }) => {
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.cancelByBatchDestination.initiate(
|
||||
{ destination: 'canvas' },
|
||||
queueApi.endpoints.cancelByBatchOrigin.initiate(
|
||||
{ origin: 'canvas' },
|
||||
{ fixedCacheKey: 'cancelByBatchOrigin' }
|
||||
)
|
||||
);
|
||||
const { canceled } = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
|
||||
if (canceled > 0) {
|
||||
log.debug(`Canceled ${canceled} canvas batches`);
|
||||
toast({
|
||||
@@ -49,11 +52,11 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: stagingAreaImageAccepted,
|
||||
actionCreator: sessionStagingAreaImageAccepted,
|
||||
effect: (action, api) => {
|
||||
const { index } = action.payload;
|
||||
const state = api.getState();
|
||||
const stagingAreaImage = state.canvasStagingArea.stagedImages[index];
|
||||
const stagingAreaImage = state.canvasSession.stagedImages[index];
|
||||
|
||||
assert(stagingAreaImage, 'No staged image found to accept');
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
@@ -66,7 +69,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
};
|
||||
|
||||
api.dispatch(rasterLayerAdded({ overrides, isSelected: false }));
|
||||
api.dispatch(stagingAreaReset());
|
||||
api.dispatch(sessionStagingAreaReset());
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { $lastCanvasProgressEvent } from 'features/controlLayers/store/canvasSlice';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
/**
|
||||
* To prevent a race condition where a progress event arrives after a successful cancellation, we need to keep track of
|
||||
* cancellations:
|
||||
* - In the route handlers above, we track and update the cancellations object
|
||||
* - When the user queues a batch, we should reset the cancellations, also handled in the route handlers above
|
||||
* - When we get a progress event, we should check if the event is cancelled before setting the event
|
||||
*
|
||||
* We have a few ways that cancellations are effected, so we need to track them all:
|
||||
* - by queue item id (in this case, we will compare the session_id and not the item_id)
|
||||
* - by batch id
|
||||
* - by destination
|
||||
* - by clearing the queue
|
||||
*/
|
||||
type Cancellations = {
|
||||
sessionIds: Set<string>;
|
||||
batchIds: Set<string>;
|
||||
destinations: Set<string>;
|
||||
clearQueue: boolean;
|
||||
};
|
||||
|
||||
const resetCancellations = (): void => {
|
||||
cancellations.clearQueue = false;
|
||||
cancellations.sessionIds.clear();
|
||||
cancellations.batchIds.clear();
|
||||
cancellations.destinations.clear();
|
||||
};
|
||||
|
||||
const cancellations: Cancellations = {
|
||||
sessionIds: new Set(),
|
||||
batchIds: new Set(),
|
||||
destinations: new Set(),
|
||||
clearQueue: false,
|
||||
} as Readonly<Cancellations>;
|
||||
|
||||
/**
|
||||
* Checks if an item is cancelled, used to prevent race conditions with event handling.
|
||||
*
|
||||
* To use this, provide the session_id, batch_id and destination from the event payload.
|
||||
*/
|
||||
export const getIsCancelled = (item: {
|
||||
session_id: string;
|
||||
batch_id: string;
|
||||
destination?: string | null;
|
||||
}): boolean => {
|
||||
if (cancellations.clearQueue) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.sessionIds.has(item.session_id)) {
|
||||
return true;
|
||||
}
|
||||
if (cancellations.batchIds.has(item.batch_id)) {
|
||||
return true;
|
||||
}
|
||||
if (item.destination && cancellations.destinations.has(item.destination)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
export const addCancellationsListeners = (startAppListening: AppStartListening) => {
|
||||
// When we get a cancellation, we may need to clear the last progress event - next few listeners handle those cases.
|
||||
// Maybe we could use the `getIsCancelled` util here, but I think that could introduce _another_ race condition...
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||
effect: () => {
|
||||
resetCancellations();
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchDestination.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.destinations.add(action.meta.arg.originalArgs.destination);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
effect: (action) => {
|
||||
cancellations.sessionIds.add(action.payload.session_id);
|
||||
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.cancelByBatchIds.matchFulfilled,
|
||||
effect: (action) => {
|
||||
for (const batch_id of action.meta.arg.originalArgs.batch_ids) {
|
||||
cancellations.batchIds.add(batch_id);
|
||||
}
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.clearQueue.matchFulfilled,
|
||||
effect: () => {
|
||||
cancellations.clearQueue = true;
|
||||
const event = $lastCanvasProgressEvent.get();
|
||||
if (!event) {
|
||||
return;
|
||||
}
|
||||
const { session_id, batch_id, destination } = event;
|
||||
if (getIsCancelled({ session_id, batch_id, destination })) {
|
||||
$lastCanvasProgressEvent.set(null);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
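The deleted module above documents how cancellations are tracked to avoid a race where a progress event arrives after a successful cancellation. A minimal usage sketch of the getIsCancelled check it describes — the type and helper mirror the removed code, while the concrete ids below are made-up examples:

// Mirrors the removed module: record everything that has been cancelled so far.
type Cancellations = {
  sessionIds: Set<string>;
  batchIds: Set<string>;
  destinations: Set<string>;
  clearQueue: boolean;
};

const cancellations: Cancellations = {
  sessionIds: new Set(),
  batchIds: new Set(),
  destinations: new Set(),
  clearQueue: false,
};

// A progress event is stale if anything identifying it has been cancelled.
const getIsCancelled = (item: { session_id: string; batch_id: string; destination?: string | null }): boolean =>
  cancellations.clearQueue ||
  cancellations.sessionIds.has(item.session_id) ||
  cancellations.batchIds.has(item.batch_id) ||
  Boolean(item.destination && cancellations.destinations.has(item.destination));

// Usage: after a cancel-by-destination request, a late progress event for that
// destination should be discarded rather than shown.
cancellations.destinations.add('canvas');
const lateEvent = { session_id: 's-1', batch_id: 'b-1', destination: 'canvas' };
if (getIsCancelled(lateEvent)) {
  // e.g. the removed listeners would call $lastCanvasProgressEvent.set(null) here
}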
@@ -4,12 +4,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import type { Result } from 'common/util/result';
|
||||
import { isErr, withResult, withResultAsync } from 'common/util/result';
|
||||
import { $canvasManager } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
selectIsStaging,
|
||||
stagingAreaReset,
|
||||
stagingAreaStartedStaging,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
@@ -35,14 +31,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
let didStartStaging = false;
|
||||
|
||||
if (!selectIsStaging(state) && state.canvasSettings.sendToCanvas) {
|
||||
dispatch(stagingAreaStartedStaging());
|
||||
if (!state.canvasSession.isStaging && state.canvasSession.sendToCanvas) {
|
||||
dispatch(sessionStartedStaging());
|
||||
didStartStaging = true;
|
||||
}
|
||||
|
||||
const abortStaging = () => {
|
||||
if (didStartStaging && selectIsStaging(getState())) {
|
||||
dispatch(stagingAreaReset());
|
||||
if (didStartStaging && getState().canvasSession.isStaging) {
|
||||
dispatch(sessionStagingAreaReset());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -74,7 +70,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
const { g, noise, posCond } = buildGraphResult.value;
|
||||
|
||||
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
|
||||
const destination = state.canvasSession.sendToCanvas ? 'canvas' : 'gallery';
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
|
||||
|
||||
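The hunk above moves the linear enqueue listener from the canvasStagingArea slice to the canvasSession slice: when the user has chosen to send results to the canvas and staging is not already active, staging starts and the batch destination becomes 'canvas'; otherwise results go to the gallery. A condensed sketch of that routing decision, with the state fields assumed from the diff:

// Minimal state shape inferred from the fields the hunk reads.
type CanvasSessionState = { isStaging: boolean; sendToCanvas: boolean };
type Destination = 'canvas' | 'gallery';

const resolveDestination = (session: CanvasSessionState): { destination: Destination; shouldStartStaging: boolean } => ({
  // Staging only needs to start for canvas-bound generations that are not already staging.
  shouldStartStaging: !session.isStaging && session.sendToCanvas,
  destination: session.sendToCanvas ? 'canvas' : 'gallery',
});

// Canvas-bound generation with no active staging session: start staging, route to canvas.
console.log(resolveDestination({ isStaging: false, sendToCanvas: true }));
// Gallery-bound generation: never starts staging.
console.log(resolveDestination({ isStaging: false, sendToCanvas: false }));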
@@ -1,5 +1,6 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
@@ -10,6 +11,7 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { shouldShowProgressInViewer } = state.ui;
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
|
||||
@@ -23,6 +25,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
);
|
||||
try {
|
||||
await req.unwrap();
|
||||
if (shouldShowProgressInViewer) {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
//TODO(psyche): handle image deletion (canvas staging area?)
|
||||
//TODO(psyche): handle image deletion (canvas sessions?)
|
||||
|
||||
// Some utils to delete images from different parts of the app
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import {
|
||||
controlLayerAdded,
|
||||
ipaImageChanged,
|
||||
@@ -13,7 +12,7 @@ import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/c
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imageToCompareChanged, isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -104,14 +103,11 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const state = getState();
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
const defaultControlAdapter = selectDefaultControlAdapter(state);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
controlAdapter: defaultControlAdapter,
|
||||
};
|
||||
dispatch(controlLayerAdded({ overrides, isSelected: true }));
|
||||
return;
|
||||
@@ -146,6 +142,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -45,8 +44,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
if (!autoAddBoardId || autoAddBoardId === 'none') {
|
||||
const title = postUploadAction.title || DEFAULT_UPLOADED_TOAST.title;
|
||||
toast({ ...DEFAULT_UPLOADED_TOAST, title });
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
} else {
|
||||
// Add this image to the board
|
||||
dispatch(
|
||||
@@ -70,8 +67,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description,
|
||||
});
|
||||
dispatch(boardIdSelected({ boardId: autoAddBoardId }));
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { calculateNewSize } from 'features/parameters/components/Bbox/calculateNewSize';
|
||||
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
|
||||
@@ -6,14 +6,12 @@ import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
|
||||
import type { SerializableObject } from 'common/types';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
canvasStagingAreaPersistConfig,
|
||||
canvasStagingAreaSlice,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice';
|
||||
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
|
||||
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
@@ -65,8 +63,9 @@ const allReducers = {
|
||||
[upscaleSlice.name]: upscaleSlice.reducer,
|
||||
[stylePresetSlice.name]: stylePresetSlice.reducer,
|
||||
[paramsSlice.name]: paramsSlice.reducer,
|
||||
[toolSlice.name]: toolSlice.reducer,
|
||||
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
|
||||
[canvasStagingAreaSlice.name]: canvasStagingAreaSlice.reducer,
|
||||
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
};
|
||||
|
||||
@@ -110,8 +109,9 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[upscalePersistConfig.name]: upscalePersistConfig,
|
||||
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
|
||||
[paramsPersistConfig.name]: paramsPersistConfig,
|
||||
[toolPersistConfig.name]: toolPersistConfig,
|
||||
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
|
||||
[canvasStagingAreaPersistConfig.name]: canvasStagingAreaPersistConfig,
|
||||
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
};
|
||||
|
||||
|
||||
@@ -6,51 +6,18 @@ import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
|
||||
import IAIDraggable from './IAIDraggable';
|
||||
import IAIDroppable from './IAIDroppable';
|
||||
import SelectionOverlay from './SelectionOverlay';
|
||||
|
||||
const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
|
||||
|
||||
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
'.gallery-image-container::before': {
|
||||
content: '""',
|
||||
display: 'inline-block',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
pointerEvents: 'none',
|
||||
borderRadius: 'base',
|
||||
},
|
||||
'&[data-selected="selected"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&[data-selected="selectedForCompare"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
'&:hover>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected="selected"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected="selectedForCompare"]>.gallery-image-container::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
};
|
||||
|
||||
type IAIDndImageProps = FlexProps & {
|
||||
imageDTO: ImageDTO | undefined;
|
||||
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||
@@ -108,11 +75,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const handleMouseOver = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOver) {
|
||||
onMouseOver(e);
|
||||
}
|
||||
setIsHovered(true);
|
||||
},
|
||||
[onMouseOver]
|
||||
);
|
||||
@@ -121,6 +90,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
if (onMouseOut) {
|
||||
onMouseOut(e);
|
||||
}
|
||||
setIsHovered(false);
|
||||
},
|
||||
[onMouseOut]
|
||||
);
|
||||
@@ -171,13 +141,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : undefined}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
@@ -200,6 +167,11 @@ const IAIDndImage = (props: IAIDndImageProps) => {
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
<SelectionOverlay
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
isHovered={withHoverOverlay ? isHovered : false}
/>
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
invokeai/frontend/web/src/common/components/IconSwitch.tsx (new file, 104 lines)
@@ -0,0 +1,104 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Tooltip, useToken } from '@invoke-ai/ui-library';
import type { ReactElement, ReactNode } from 'react';
import { memo, useCallback, useMemo } from 'react';

type IconSwitchProps = {
isChecked: boolean;
onChange: (checked: boolean) => void;
iconChecked: ReactElement;
tooltipChecked?: ReactNode;
iconUnchecked: ReactElement;
tooltipUnchecked?: ReactNode;
ariaLabel: string;
};

const getSx = (padding: string | number): SystemStyleObject => ({
transition: 'left 0.1s ease-in-out, transform 0.1s ease-in-out',
'&[data-checked="true"]': {
left: `calc(100% - ${padding})`,
transform: 'translateX(-100%)',
},
'&[data-checked="false"]': {
left: padding,
transform: 'translateX(0)',
},
});

export const IconSwitch = memo(
({
isChecked,
onChange,
iconChecked,
tooltipChecked,
iconUnchecked,
tooltipUnchecked,
ariaLabel,
}: IconSwitchProps) => {
const onUncheck = useCallback(() => {
onChange(false);
}, [onChange]);
const onCheck = useCallback(() => {
onChange(true);
}, [onChange]);

const gap = useToken('space', 1.5);
const sx = useMemo(() => getSx(gap), [gap]);

return (
<Flex
position="relative"
bg="base.800"
borderRadius="base"
alignItems="center"
justifyContent="center"
h="full"
p={gap}
gap={gap}
>
<Box
position="absolute"
borderRadius="base"
bg="invokeBlue.400"
w={12}
top={gap}
bottom={gap}
data-checked={isChecked}
sx={sx}
/>
<Tooltip hasArrow label={tooltipUnchecked}>
<IconButton
size="sm"
fontSize={16}
icon={iconUnchecked}
onClick={onUncheck}
variant={!isChecked ? 'solid' : 'ghost'}
colorScheme={!isChecked ? 'invokeBlue' : 'base'}
aria-label={ariaLabel}
data-checked={!isChecked}
w={12}
alignSelf="stretch"
h="auto"
/>
</Tooltip>
<Tooltip hasArrow label={tooltipChecked}>
<IconButton
size="sm"
fontSize={16}
icon={iconChecked}
onClick={onCheck}
variant={isChecked ? 'solid' : 'ghost'}
colorScheme={isChecked ? 'invokeBlue' : 'base'}
aria-label={ariaLabel}
data-checked={isChecked}
w={12}
alignSelf="stretch"
h="auto"
/>
</Tooltip>
</Flex>
);
}
);

IconSwitch.displayName = 'IconSwitch';
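
Taken on its own, the new IconSwitch is a two-option segmented toggle whose highlight slides between the two icon buttons. Below is a minimal usage sketch, assuming the import path shown in the file header above and borrowing two icons that already appear in this diff; the wrapper component, labels, and icon choices are illustrative only and not part of the change.

```tsx
// Hypothetical usage of the new IconSwitch; names, labels, and icons are illustrative.
import { IconSwitch } from 'common/components/IconSwitch';
import { useState } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';

const ExampleViewToggle = () => {
  // Local state drives the switch; IconSwitch itself is fully controlled.
  const [isChecked, setIsChecked] = useState(false);
  return (
    <IconSwitch
      isChecked={isChecked}
      onChange={setIsChecked}
      iconChecked={<PiImageBold />}
      tooltipChecked="Checked state"
      iconUnchecked={<PiUploadSimpleBold />}
      tooltipUnchecked="Unchecked state"
      ariaLabel="Toggle example view"
    />
  );
};

export default ExampleViewToggle;
```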
@@ -0,0 +1,46 @@
import { Box } from '@invoke-ai/ui-library';
import { memo, useMemo } from 'react';

type Props = {
isSelected: boolean;
isSelectedForCompare: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelectedForCompare && isHovered) {
return 'hoverSelectedForCompare';
}
if (isSelectedForCompare && !isHovered) {
return 'selectedForCompare';
}
if (isSelected && isHovered) {
return 'hoverSelected';
}
if (isSelected && !isHovered) {
return 'selected';
}
if (!isSelected && isHovered) {
return 'hoverUnselected';
}
return undefined;
}, [isHovered, isSelected, isSelectedForCompare]);
return (
<Box
className="selection-box"
position="absolute"
top={0}
insetInlineEnd={0}
bottom={0}
insetInlineStart={0}
borderRadius="base"
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
transitionProperty="common"
transitionDuration="0.1s"
pointerEvents="none"
shadow={shadow}
/>
);
};

export default memo(SelectionOverlay);
@@ -1,74 +1,52 @@
import { useStore } from '@nanostores/react';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import { useCallback, useState } from 'react';
import { useCallback, useMemo, useState } from 'react';

type UseBoolean = {
isTrue: boolean;
setTrue: () => void;
setFalse: () => void;
set: (value: boolean) => void;
toggle: () => void;
};

/**
* Creates a hook to manage a boolean state. The boolean is stored in a nanostores atom.
* Returns a tuple containing the hook and the atom. Use this for global boolean state.
* @param initialValue Initial value of the boolean
*/
export const buildUseBoolean = (initialValue: boolean): [() => UseBoolean, WritableAtom<boolean>] => {
const $boolean = atom(initialValue);

const setTrue = () => {
$boolean.set(true);
};
const setFalse = () => {
$boolean.set(false);
};
const set = (value: boolean) => {
$boolean.set(value);
};
const toggle = () => {
$boolean.set(!$boolean.get());
};

const useBoolean = () => {
const isTrue = useStore($boolean);

return {
isTrue,
setTrue,
setFalse,
set,
toggle,
};
};

return [useBoolean, $boolean] as const;
};

/**
* Hook to manage a boolean state. Use this for a local boolean state.
* @param initialValue Initial value of the boolean
*/
export const useBoolean = (initialValue: boolean) => {
const [isTrue, set] = useState(initialValue);
const setTrue = useCallback(() => set(true), []);
const setFalse = useCallback(() => set(false), []);
const toggle = useCallback(() => set((v) => !v), []);

const setTrue = useCallback(() => {
set(true);
}, [set]);
const setFalse = useCallback(() => {
set(false);
}, [set]);
const toggle = useCallback(() => {
set((val) => !val);
}, [set]);
const api = useMemo(
() => ({
isTrue,
set,
setTrue,
setFalse,
toggle,
}),
[isTrue, set, setTrue, setFalse, toggle]
);

return {
isTrue,
setTrue,
setFalse,
set,
toggle,
return api;
};

export const buildUseBoolean = ($boolean: WritableAtom<boolean>) => {
return () => {
const setTrue = useCallback(() => {
$boolean.set(true);
}, []);
const setFalse = useCallback(() => {
$boolean.set(false);
}, []);
const set = useCallback((value: boolean) => {
$boolean.set(value);
}, []);
const toggle = useCallback(() => {
$boolean.set(!$boolean.get());
}, []);

const api = useMemo(
() => ({
setTrue,
setFalse,
set,
toggle,
$boolean,
}),
[set, setFalse, setTrue, toggle]
);

return api;
};
};
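
Reading the hunk lines above as old-then-new, the tuple-returning buildUseBoolean(initialValue) appears to be replaced by a variant that accepts a caller-owned nanostores atom, while the local useBoolean keeps its signature but now returns a memoized API. A consumption sketch under those assumptions follows; the atom name, wrapper hook, and import path are invented for illustration, since the file is not named in this diff.

```ts
// Illustrative only: $exampleFlag, useExample, and the import path are assumptions.
import { useStore } from '@nanostores/react';
import { atom } from 'nanostores';
import { buildUseBoolean, useBoolean } from 'common/hooks/useBoolean';

// Global flag: the caller now creates the atom and hands it to the builder.
const $exampleFlag = atom(false);
const useExampleFlag = buildUseBoolean($exampleFlag);

const useExample = () => {
  // Local flag: same call shape as before, but the returned API is memoized.
  const { isTrue, toggle } = useBoolean(false);
  // Global flag: the current value is read via the atom included in the returned API.
  const exampleFlag = useExampleFlag();
  const isExampleFlagTrue = useStore(exampleFlag.$boolean);
  return { isTrue, toggle, isExampleFlagTrue, setExampleFlag: exampleFlag.set };
};
```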
@@ -114,13 +114,4 @@ export const useGlobalHotkeys = () => {
},
[dispatch, isModelManagerEnabled]
);

useHotkeys(
isModelManagerEnabled ? '6' : '5',
() => {
dispatch(setActiveTab('gallery'));
setScopes([]);
},
[dispatch, isModelManagerEnabled]
);
};
@@ -2,7 +2,6 @@ import { useStore } from '@nanostores/react';
import { $isConnected } from 'app/hooks/useSocketIO';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
@@ -18,7 +17,6 @@ import { selectSystemSlice } from 'features/system/store/systemSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { atom } from 'nanostores';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';

@@ -30,7 +28,7 @@ const LAYER_TYPE_TO_TKEY = {
control_layer: 'controlLayers.globalControlAdapter',
} as const;

const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy: boolean) =>
const createSelector = (templates: Templates, isConnected: boolean) =>
createMemoizedSelector(
[
selectSystemSlice,
@@ -119,10 +117,6 @@ const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
} else {
if (canvasIsBusy) {
reasons.push({ content: i18n.t('parameters.invoke.canvasBusy') });
}

if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
@@ -134,7 +128,7 @@ const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy
canvas.controlLayers.entities
.filter((controlLayer) => controlLayer.isEnabled)
.forEach((controlLayer, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
@@ -164,7 +158,7 @@ const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy
canvas.ipAdapters.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
@@ -192,7 +186,7 @@ const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy
canvas.regions.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
@@ -229,7 +223,7 @@ const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy
canvas.rasterLayers.entities
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layer_one');
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
@@ -246,17 +240,10 @@ const createSelector = (templates: Templates, isConnected: boolean, canvasIsBusy
}
);

const dummyAtom = atom(true);

export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const isConnected = useStore($isConnected);
const canvasManager = useCanvasManagerSafe();
const canvasIsBusy = useStore(canvasManager?.$isBusy ?? dummyAtom);
const selector = useMemo(
() => createSelector(templates, isConnected, canvasIsBusy),
[templates, isConnected, canvasIsBusy]
);
const selector = useMemo(() => createSelector(templates, isConnected), [templates, isConnected]);
const value = useAppSelector(selector);
return value;
};
@@ -1,9 +1,6 @@
export const roundDownToMultiple = (num: number, multiple: number): number => {
return Math.floor(num / multiple) * multiple;
};
export const roundUpToMultiple = (num: number, multiple: number): number => {
return Math.ceil(num / multiple) * multiple;
};

export const roundToMultiple = (num: number, multiple: number): number => {
return Math.round(num / multiple) * multiple;
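
The three helpers differ only in rounding mode (floor, ceil, nearest). A few worked values under made-up inputs, of the kind used when snapping a dimension to a required multiple:

```ts
// Worked values for the rounding helpers above; the inputs are illustrative and the
// import is omitted because the utility file is not named in this diff.
const down = Math.floor(1023 / 8) * 8; // 1016, same as roundDownToMultiple(1023, 8)
const up = Math.ceil(1023 / 8) * 8; // 1024, same as roundUpToMultiple(1023, 8)
const nearest = Math.round(1023 / 8) * 8; // 1024, same as roundToMultiple(1023, 8)
```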
@@ -1,22 +1,34 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useAddControlLayer,
useAddInpaintMask,
useAddIPAdapter,
useAddRasterLayer,
useAddRegionalGuidance,
} from 'features/controlLayers/hooks/addLayerHooks';
import { memo } from 'react';
controlLayerAdded,
inpaintMaskAdded,
ipaAdded,
rasterLayerAdded,
rgAdded,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

export const CanvasAddEntityButtons = memo(() => {
const { t } = useTranslation();
const addInpaintMask = useAddInpaintMask();
const addRegionalGuidance = useAddRegionalGuidance();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const addIPAdapter = useAddIPAdapter();
const dispatch = useAppDispatch();
const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
const addControlLayer = useCallback(() => {
dispatch(controlLayerAdded({ isSelected: true }));
}, [dispatch]);
const addIPAdapter = useCallback(() => {
dispatch(ipaAdded({ isSelected: true }));
}, [dispatch]);

return (
<Flex flexDir="column" w="full" h="full" alignItems="center">
@@ -34,7 +46,7 @@ export const CanvasAddEntityButtons = memo(() => {
{t('controlLayers.controlLayer')}
</Button>
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
{t('controlLayers.globalIPAdapter')}
{t('controlLayers.ipAdapter')}
</Button>
</ButtonGroup>
</Flex>
@@ -1,6 +1,7 @@
import { Flex } from '@invoke-ai/ui-library';
import IAIDroppable from 'common/components/IAIDroppable';
import type { AddControlLayerFromImageDropData, AddRasterLayerFromImageDropData } from 'features/dnd/types';
import { useIsImageViewerOpen } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo } from 'react';

const addRasterLayerFromImageDropData: AddRasterLayerFromImageDropData = {
@@ -14,6 +15,12 @@ const addControlLayerFromImageDropData: AddControlLayerFromImageDropData = {
};

export const CanvasDropArea = memo(() => {
const isImageViewerOpen = useIsImageViewerOpen();

if (isImageViewerOpen) {
return null;
}

return (
<>
<Flex position="absolute" top={0} right={0} bottom="50%" left={0} gap={2} pointerEvents="none">
@@ -0,0 +1,20 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { EntityListActionBarAddLayerButton } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarAddLayerMenuButton';
import { EntityListActionBarDeleteButton } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarDeleteButton';
import { EntityListActionBarSelectedEntityFill } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarSelectedEntityFill';
import { SelectedEntityOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarSelectedEntityOpacity';
import { memo } from 'react';

export const EntityListActionBar = memo(() => {
return (
<Flex w="full" py={1} px={1} gap={2} alignItems="center">
<SelectedEntityOpacity />
<Spacer />
<EntityListActionBarSelectedEntityFill />
<EntityListActionBarAddLayerButton />
<EntityListActionBarDeleteButton />
</Flex>
);
});

EntityListActionBar.displayName = 'EntityListActionBar';
@@ -0,0 +1,28 @@
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { CanvasEntityListMenuItems } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarAddLayerMenuItems';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

export const EntityListActionBarAddLayerButton = memo(() => {
const { t } = useTranslation();

return (
<Menu>
<MenuButton
as={IconButton}
size="sm"
tooltip={t('controlLayers.addLayer')}
aria-label={t('controlLayers.addLayer')}
icon={<PiPlusBold />}
variant="ghost"
data-testid="control-layers-add-layer-menu-button"
/>
<MenuList>
<CanvasEntityListMenuItems />
</MenuList>
</Menu>
);
});

EntityListActionBarAddLayerButton.displayName = 'EntityListActionBarAddLayerButton';
@@ -0,0 +1,54 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
controlLayerAdded,
inpaintMaskAdded,
ipaAdded,
rasterLayerAdded,
rgAdded,
} from 'features/controlLayers/store/canvasSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

export const CanvasEntityListMenuItems = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const addInpaintMask = useCallback(() => {
dispatch(inpaintMaskAdded({ isSelected: true }));
}, [dispatch]);
const addRegionalGuidance = useCallback(() => {
dispatch(rgAdded({ isSelected: true }));
}, [dispatch]);
const addRasterLayer = useCallback(() => {
dispatch(rasterLayerAdded({ isSelected: true }));
}, [dispatch]);
const addControlLayer = useCallback(() => {
dispatch(controlLayerAdded({ isSelected: true }));
}, [dispatch]);
const addIPAdapter = useCallback(() => {
dispatch(ipaAdded({ isSelected: true }));
}, [dispatch]);

return (
<>
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
{t('controlLayers.rasterLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addIPAdapter}>
{t('controlLayers.ipAdapter')}
</MenuItem>
</>
);
});

CanvasEntityListMenuItems.displayName = 'CanvasEntityListMenu';
@@ -0,0 +1,39 @@
import { IconButton, useShiftModifier } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { allEntitiesDeleted, entityDeleted } from 'features/controlLayers/store/canvasSlice';
import { selectEntityCount, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';

export const EntityListActionBarDeleteButton = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const entityCount = useAppSelector(selectEntityCount);
const shift = useShiftModifier();
const onClick = useCallback(() => {
if (shift) {
dispatch(allEntitiesDeleted());
return;
}
if (!selectedEntityIdentifier) {
return;
}
dispatch(entityDeleted({ entityIdentifier: selectedEntityIdentifier }));
}, [dispatch, selectedEntityIdentifier, shift]);

return (
<IconButton
onClick={onClick}
isDisabled={shift ? entityCount === 0 : !selectedEntityIdentifier}
size="sm"
variant="ghost"
aria-label={shift ? t('controlLayers.deleteAll') : t('controlLayers.deleteSelected')}
tooltip={shift ? t('controlLayers.deleteAll') : t('controlLayers.deleteSelected')}
icon={<PiTrashSimpleFill />}
/>
);
});

EntityListActionBarDeleteButton.displayName = 'EntityListActionBarDeleteButton';
Some files were not shown because too many files have changed in this diff.