mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 04:17:59 -05:00
Compare commits
507 Commits
v4.2.9.dev
...
v4.2.9.dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
434afdd21d | ||
|
|
ea267c0e09 | ||
|
|
35b483b83a | ||
|
|
217b2759c3 | ||
|
|
7d528dc03d | ||
|
|
bc40c1a99f | ||
|
|
a2b508016e | ||
|
|
3ce8294379 | ||
|
|
9c3da8de8e | ||
|
|
ccba597e58 | ||
|
|
8028943cdd | ||
|
|
cdd8b60fd0 | ||
|
|
c624754404 | ||
|
|
0505634296 | ||
|
|
e06ea5d595 | ||
|
|
838a0574a5 | ||
|
|
ed3581c70c | ||
|
|
522cae6a42 | ||
|
|
a9032a34f2 | ||
|
|
6af5a22d3a | ||
|
|
8cf4321010 | ||
|
|
aa7f2b096a | ||
|
|
bf0824b56d | ||
|
|
056c56d322 | ||
|
|
afc6f83d72 | ||
|
|
c776ac3af2 | ||
|
|
b7b3683bef | ||
|
|
fb26b6824a | ||
|
|
63d8ad912f | ||
|
|
bbd7d7fc17 | ||
|
|
6507a78182 | ||
|
|
22f46517f4 | ||
|
|
45596e1f94 | ||
|
|
6de0dbe854 | ||
|
|
011827fa29 | ||
|
|
fc6d244071 | ||
|
|
cd3da886d6 | ||
|
|
c013c55d92 | ||
|
|
cd3dd7db0d | ||
|
|
1fdcce9429 | ||
|
|
181e40926d | ||
|
|
c62ede5878 | ||
|
|
a2ad5f1a9a | ||
|
|
ff74a5356f | ||
|
|
92dc30dace | ||
|
|
3af577b210 | ||
|
|
d0464330f7 | ||
|
|
dd3ef4a80f | ||
|
|
0ced891944 | ||
|
|
10a5452df9 | ||
|
|
cb97969bbc | ||
|
|
71e742e238 | ||
|
|
fadd20fb8e | ||
|
|
01b9ca78e4 | ||
|
|
2baf825f34 | ||
|
|
1fa8048509 | ||
|
|
a000ad75f6 | ||
|
|
aefb2339bb | ||
|
|
a4f8671f86 | ||
|
|
73530ba54f | ||
|
|
685eb9927d | ||
|
|
ee57302fc3 | ||
|
|
c1fb9cdb93 | ||
|
|
aa6d441552 | ||
|
|
25d8d4c2e9 | ||
|
|
427ea6da5c | ||
|
|
d9f4266630 | ||
|
|
96f6e9e683 | ||
|
|
f10248e3f5 | ||
|
|
6a21f5fde1 | ||
|
|
ff20dd509a | ||
|
|
78a59b5b78 | ||
|
|
46bfbbbc87 | ||
|
|
a6d73d0773 | ||
|
|
6578e8bef8 | ||
|
|
0596d25e07 | ||
|
|
86e8ce9139 | ||
|
|
5aa2957da4 | ||
|
|
82f0cb2c8c | ||
|
|
fa48145cbc | ||
|
|
6d1edc330d | ||
|
|
97c0d3f6be | ||
|
|
a79a25ad63 | ||
|
|
6a8ceef404 | ||
|
|
3539670d93 | ||
|
|
c54bc32ef6 | ||
|
|
fee293e289 | ||
|
|
747eef9ccc | ||
|
|
7d2df399ed | ||
|
|
68fad5cdcc | ||
|
|
b4d656c203 | ||
|
|
3136d89d52 | ||
|
|
27e829b955 | ||
|
|
e03e870d5b | ||
|
|
9465ff450b | ||
|
|
92906a9575 | ||
|
|
77f206abe4 | ||
|
|
44a3f61580 | ||
|
|
0a2afed08b | ||
|
|
9b3b961105 | ||
|
|
9b1828e1aa | ||
|
|
5101873f49 | ||
|
|
c612f18114 | ||
|
|
7e400d876f | ||
|
|
677dddcfc9 | ||
|
|
0792b9175e | ||
|
|
e4829f80af | ||
|
|
bb760f3eb4 | ||
|
|
388c65287b | ||
|
|
12cd41e05c | ||
|
|
7765c03949 | ||
|
|
3daa80c57f | ||
|
|
5dbbef4ebd | ||
|
|
db33b3f7b5 | ||
|
|
8ffcf2a6be | ||
|
|
2e7ae6a07e | ||
|
|
fea1711f0c | ||
|
|
2a3546db97 | ||
|
|
285c266612 | ||
|
|
426ad54c53 | ||
|
|
fc75f7919f | ||
|
|
6c6b1aaff6 | ||
|
|
c319d653ac | ||
|
|
d887e474e7 | ||
|
|
da7b52d6ba | ||
|
|
b5aa308593 | ||
|
|
0b7ceb3bb6 | ||
|
|
3a70cefda2 | ||
|
|
4b609251e1 | ||
|
|
0839eac0f7 | ||
|
|
5f2a7feeee | ||
|
|
982535eb92 | ||
|
|
0c2b8edc8d | ||
|
|
f78f4ca25f | ||
|
|
d6b3e6c07d | ||
|
|
071ff8e74a | ||
|
|
1ea8aafca1 | ||
|
|
533dd221f8 | ||
|
|
2b325c6683 | ||
|
|
3845b1b3e6 | ||
|
|
cea7890a67 | ||
|
|
c38fe8025d | ||
|
|
f1de95349c | ||
|
|
2950775fa7 | ||
|
|
cb293fd7ac | ||
|
|
43b3fab6be | ||
|
|
d4b0dbce49 | ||
|
|
137b810669 | ||
|
|
c172657324 | ||
|
|
97c966b04f | ||
|
|
7178fc6253 | ||
|
|
4adb2eabf5 | ||
|
|
9f2c815e13 | ||
|
|
1435557d1d | ||
|
|
96abf687f6 | ||
|
|
636d9a7209 | ||
|
|
3b36eb0223 | ||
|
|
388c97bff0 | ||
|
|
b1cb018695 | ||
|
|
df78dd7953 | ||
|
|
0dc344a22e | ||
|
|
350d7f6f14 | ||
|
|
11059ee2d4 | ||
|
|
c90d3f3bb9 | ||
|
|
7f6d439fd1 | ||
|
|
783a78f069 | ||
|
|
0ff031950d | ||
|
|
d7e8f3d756 | ||
|
|
4668ea449b | ||
|
|
30d318d021 | ||
|
|
de96f97e5f | ||
|
|
57c0a2dfb1 | ||
|
|
cd4e464bde | ||
|
|
49e48c3eb7 | ||
|
|
edd3b3bce9 | ||
|
|
f8bfb66108 | ||
|
|
3b6a76cbf3 | ||
|
|
e0b60e4320 | ||
|
|
2159319035 | ||
|
|
b170fc232e | ||
|
|
594da60f2f | ||
|
|
6a432f6518 | ||
|
|
eb8eacfec6 | ||
|
|
c8d04d42e2 | ||
|
|
d39c9de81e | ||
|
|
a27d39b9ff | ||
|
|
6b385614f0 | ||
|
|
3ae7250ef7 | ||
|
|
a42d0ce1d2 | ||
|
|
d9131f7563 | ||
|
|
bdce958f29 | ||
|
|
3c86f1e979 | ||
|
|
894b8a29b9 | ||
|
|
8436a44973 | ||
|
|
f9f9ec3688 | ||
|
|
5a98d7a1f6 | ||
|
|
f9bc96e497 | ||
|
|
56350ff5dc | ||
|
|
6c1139340c | ||
|
|
641b1a7e6f | ||
|
|
674a3f462f | ||
|
|
2283186d3a | ||
|
|
340af1fe50 | ||
|
|
9378656d78 | ||
|
|
5f0413e222 | ||
|
|
c3e47515b1 | ||
|
|
5dcef6fa0d | ||
|
|
31ac02cd93 | ||
|
|
ab16976084 | ||
|
|
8e2b7845e1 | ||
|
|
3973bce342 | ||
|
|
f63847a504 | ||
|
|
07e3529948 | ||
|
|
03e1c60694 | ||
|
|
d766ed71fc | ||
|
|
ae68ef142a | ||
|
|
20f55768c4 | ||
|
|
c4fad4456e | ||
|
|
78f5ec44ad | ||
|
|
e14ba86942 | ||
|
|
d4e7720f6b | ||
|
|
a3f0e7e1cb | ||
|
|
30a696c476 | ||
|
|
66d6c64e16 | ||
|
|
d15be9b57c | ||
|
|
e5da902fd0 | ||
|
|
fc558094c2 | ||
|
|
ad9312e989 | ||
|
|
8e1a70b008 | ||
|
|
17f88cd5ad | ||
|
|
298f1919fa | ||
|
|
4d20cc11d4 | ||
|
|
14f249a2f0 | ||
|
|
b9746a6c2c | ||
|
|
94f298a6f4 | ||
|
|
8d3a8178da | ||
|
|
cad4212fe8 | ||
|
|
cff28dfaa9 | ||
|
|
70d7509fcc | ||
|
|
cf83af7a27 | ||
|
|
5c5a405c0f | ||
|
|
0208e4b232 | ||
|
|
e940754795 | ||
|
|
dc9fa1a735 | ||
|
|
08591fbf6d | ||
|
|
74db71bb5d | ||
|
|
60dbe798a5 | ||
|
|
0e676605fe | ||
|
|
3f781016f6 | ||
|
|
17cd2f6b02 | ||
|
|
99102a1b34 | ||
|
|
8d72e7d9e8 | ||
|
|
0b6b6f97ad | ||
|
|
fb2f6382b1 | ||
|
|
1ddea87c35 | ||
|
|
ea02323095 | ||
|
|
49733091c7 | ||
|
|
cf833fd6e2 | ||
|
|
ba5cf07ab8 | ||
|
|
d15321a373 | ||
|
|
de597a5eb4 | ||
|
|
e5f5cbdf5c | ||
|
|
7d4342bbff | ||
|
|
7f8a1d8d20 | ||
|
|
65353ac1e1 | ||
|
|
7f9a31ca4a | ||
|
|
592eb2886c | ||
|
|
c220dd8987 | ||
|
|
a263beb0d5 | ||
|
|
46b7c510eb | ||
|
|
f405e472ea | ||
|
|
7bdfd3ef5f | ||
|
|
778ee2c679 | ||
|
|
e70339ff3e | ||
|
|
88c57a9750 | ||
|
|
137252128b | ||
|
|
d4297b1345 | ||
|
|
6059bc7b47 | ||
|
|
c3ff3eb51f | ||
|
|
0b7751c413 | ||
|
|
d7f1c30624 | ||
|
|
3f4d7dbeea | ||
|
|
19b6ae2907 | ||
|
|
769f96ff9f | ||
|
|
fdaf75faa4 | ||
|
|
1380bb7ae6 | ||
|
|
9483c8cc29 | ||
|
|
2ef8a8cf5a | ||
|
|
d296ec1932 | ||
|
|
444ad3dae1 | ||
|
|
8cdcc71378 | ||
|
|
e8bc06cfd3 | ||
|
|
67a0a024e9 | ||
|
|
bd2c46c267 | ||
|
|
5acb27f350 | ||
|
|
7271b12d0f | ||
|
|
4a79467a33 | ||
|
|
5501bb87a3 | ||
|
|
561610e296 | ||
|
|
b76609ef18 | ||
|
|
070b78501b | ||
|
|
50df4f4ab6 | ||
|
|
9bbf430125 | ||
|
|
384a90958a | ||
|
|
0e4a25b029 | ||
|
|
4a44e171fd | ||
|
|
9bc57a6f59 | ||
|
|
4341ed7ab4 | ||
|
|
97ce72c542 | ||
|
|
a2c78a57a7 | ||
|
|
044a713dc9 | ||
|
|
b8479c5fe2 | ||
|
|
4e5d056824 | ||
|
|
118278b372 | ||
|
|
8e8c255f3f | ||
|
|
1575bee401 | ||
|
|
249bbfc883 | ||
|
|
3993ae410f | ||
|
|
edf040e3d2 | ||
|
|
66fd077ee7 | ||
|
|
b93462ebb6 | ||
|
|
aae60d0cdc | ||
|
|
d4da00e607 | ||
|
|
0c539ff00b | ||
|
|
5983cbf26c | ||
|
|
c513d6e3af | ||
|
|
9d57c0e631 | ||
|
|
a1923a8966 | ||
|
|
d988e18731 | ||
|
|
51008da2dd | ||
|
|
6ccc1f5672 | ||
|
|
4a556f84e0 | ||
|
|
2f21a2220d | ||
|
|
91a420b13e | ||
|
|
c27da3581b | ||
|
|
961dfbce93 | ||
|
|
df39c825ae | ||
|
|
3f6ee1b7a4 | ||
|
|
908e504a6f | ||
|
|
f2fa41afc5 | ||
|
|
440ff40ad5 | ||
|
|
5c15458e15 | ||
|
|
be5b474f1e | ||
|
|
cee178c2b6 | ||
|
|
27657f8b7a | ||
|
|
e0cde3815a | ||
|
|
09d0421de4 | ||
|
|
47b94d563c | ||
|
|
0b5d20c9f0 | ||
|
|
80e7e1293a | ||
|
|
3a82b0cbc1 | ||
|
|
a27cbc13b6 | ||
|
|
a8f962eb3f | ||
|
|
7f40d23f19 | ||
|
|
918354cd9d | ||
|
|
eef9278ee6 | ||
|
|
2c32e2e5c1 | ||
|
|
6f05654db5 | ||
|
|
1d31b6902f | ||
|
|
5a7d615e64 | ||
|
|
1dbf9e4ed4 | ||
|
|
5dcc6ee203 | ||
|
|
84a4e6ae3f | ||
|
|
f283bfd68f | ||
|
|
6e5ff7b79c | ||
|
|
7c3800d03f | ||
|
|
941db90518 | ||
|
|
0d9ecf0f90 | ||
|
|
9c77023a11 | ||
|
|
b55378c63c | ||
|
|
946c2a49ab | ||
|
|
b823c31ec6 | ||
|
|
ec6361e5cb | ||
|
|
0c26d28278 | ||
|
|
c5172d4c5a | ||
|
|
89de04775e | ||
|
|
b4c3c940b5 | ||
|
|
aee2aad959 | ||
|
|
5ca48a8a5f | ||
|
|
1806aa187b | ||
|
|
7824cb7a1a | ||
|
|
9807a896f4 | ||
|
|
19866f057d | ||
|
|
ec4eae3c9c | ||
|
|
bea0cba038 | ||
|
|
48ee75af9c | ||
|
|
929c593d2f | ||
|
|
221f32eca7 | ||
|
|
c3acc15e8b | ||
|
|
1b653278fc | ||
|
|
cc9062ee46 | ||
|
|
91c0feb0ad | ||
|
|
ae60292ac8 | ||
|
|
a6ca17b19d | ||
|
|
6a4a5ece74 | ||
|
|
9b81860307 | ||
|
|
5f4a3928d2 | ||
|
|
b703884763 | ||
|
|
32da98ab8f | ||
|
|
bd5a85bf70 | ||
|
|
d045f24014 | ||
|
|
2aad3f89c3 | ||
|
|
dd54d19f00 | ||
|
|
0ed6591d8c | ||
|
|
712e090134 | ||
|
|
8fc2a1d1cf | ||
|
|
cc15c1593e | ||
|
|
9997d3abda | ||
|
|
031471e785 | ||
|
|
2e860c6791 | ||
|
|
d071a9e17d | ||
|
|
ed53d33321 | ||
|
|
382bc6d978 | ||
|
|
dab42e258c | ||
|
|
81556410bb | ||
|
|
1f2dfd473c | ||
|
|
8f0f51be2c | ||
|
|
7179e250ed | ||
|
|
5bec091fd6 | ||
|
|
2c5896cb0c | ||
|
|
93ff252dc0 | ||
|
|
ac52224455 | ||
|
|
4087cad23d | ||
|
|
e936b1ff8f | ||
|
|
b7f9c5e221 | ||
|
|
fc5467150e | ||
|
|
4dcab357a0 | ||
|
|
695e464255 | ||
|
|
9999b60c3b | ||
|
|
e7df53e260 | ||
|
|
844590a571 | ||
|
|
9622beaa0d | ||
|
|
007e2553a8 | ||
|
|
15ad4e3f5e | ||
|
|
be5094fcb4 | ||
|
|
a20a861680 | ||
|
|
396d0a4bc0 | ||
|
|
ca9314e077 | ||
|
|
4b848798e7 | ||
|
|
083bcbc77d | ||
|
|
e8cdc9ae62 | ||
|
|
8abfa759a4 | ||
|
|
f6a324b633 | ||
|
|
f083be9391 | ||
|
|
091e2fb751 | ||
|
|
d8539daf1f | ||
|
|
7ec059f5fa | ||
|
|
4f2ecdefd2 | ||
|
|
e8891a1988 | ||
|
|
37d2607f34 | ||
|
|
0e7b10d3d9 | ||
|
|
1f85888638 | ||
|
|
c1f9a129fa | ||
|
|
7ccc5ba398 | ||
|
|
5e1a6ae334 | ||
|
|
3f6cf638f9 | ||
|
|
46e062a828 | ||
|
|
cc3a0b5d6c | ||
|
|
775479ab7b | ||
|
|
6b9e0e6d63 | ||
|
|
83a5c87f5e | ||
|
|
84fde74331 | ||
|
|
a517e29b39 | ||
|
|
ccceba7565 | ||
|
|
5fc7a03669 | ||
|
|
8864ad1b50 | ||
|
|
f2989885fb | ||
|
|
19c66e5c76 | ||
|
|
8a6690a57c | ||
|
|
2cc60f253a | ||
|
|
cb69872dd3 | ||
|
|
ba66d7c9a6 | ||
|
|
9fe727c9f8 | ||
|
|
58c656224f | ||
|
|
c51253f5f6 | ||
|
|
6c1d1588fc | ||
|
|
95d6183a6c | ||
|
|
f4c9facdaf | ||
|
|
a274e6f165 | ||
|
|
154e3e6f64 | ||
|
|
2c3ac972e5 | ||
|
|
2e2e072b0b | ||
|
|
9d8dd2bf66 | ||
|
|
9047f6db30 | ||
|
|
5ab345ee63 | ||
|
|
d8a83acd3a | ||
|
|
1f58e5756b | ||
|
|
744acb8f07 | ||
|
|
ae7228d821 | ||
|
|
99d8b3a7bf | ||
|
|
3fbe65bbcf | ||
|
|
f5d879d8e7 | ||
|
|
cbc5a4f8e6 | ||
|
|
37ac7d8ed5 | ||
|
|
bee3fa339d | ||
|
|
c171fe2b96 | ||
|
|
1fa8032fdb | ||
|
|
5c0676bcc2 | ||
|
|
cefd9a027c | ||
|
|
1bce156de1 | ||
|
|
cd4f63f2fd | ||
|
|
3c7140cbf3 | ||
|
|
b71ba63b5a | ||
|
|
d540e2c0d3 | ||
|
|
d79fafc5f5 | ||
|
|
9e93fa2092 | ||
|
|
392e9b4882 |
37
.github/workflows/build-container.yml
vendored
37
.github/workflows/build-container.yml
vendored
@@ -13,12 +13,6 @@ on:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
push-to-registry:
|
||||
description: Push the built image to the container registry
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -56,15 +50,16 @@ jobs:
|
||||
df -h
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
images: |
|
||||
ghcr.io/${{ github.repository }}
|
||||
${{ env.DOCKERHUB_REPOSITORY }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
@@ -77,33 +72,49 @@ jobs:
|
||||
suffix=-${{ matrix.gpu-driver }},onlatest=false
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
with:
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# - name: Login to Docker Hub
|
||||
# if: github.event_name != 'pull_request' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
# uses: docker/login-action@v2
|
||||
# with:
|
||||
# username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
# password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build container
|
||||
timeout-minutes: 40
|
||||
id: docker_build
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile
|
||||
platforms: ${{ env.PLATFORMS }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
|
||||
# - name: Docker Hub Description
|
||||
# if: github.ref == 'refs/heads/main' || github.ref == 'refs/tags/*' && vars.DOCKERHUB_REPOSITORY != ''
|
||||
# uses: peter-evans/dockerhub-description@v3
|
||||
# with:
|
||||
# username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
# password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# repository: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
# short-description: ${{ github.event.repository.description }}
|
||||
|
||||
2
.github/workflows/python-tests.yml
vendored
2
.github/workflows/python-tests.yml
vendored
@@ -60,7 +60,7 @@ jobs:
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
os: macOS-14
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: windows-cpu
|
||||
os: windows-2022
|
||||
|
||||
@@ -196,22 +196,6 @@ tips to reduce the problem:
|
||||
=== "12GB VRAM GPU"
|
||||
|
||||
This should be sufficient to generate larger images up to about 1280x1280.
|
||||
|
||||
## Checkpoint Models Load Slowly or Use Too Much RAM
|
||||
|
||||
The difference between diffusers models (a folder containing multiple
|
||||
subfolders) and checkpoint models (a file ending with .safetensors or
|
||||
.ckpt) is that InvokeAI is able to load diffusers models into memory
|
||||
incrementally, while checkpoint models must be loaded all at
|
||||
once. With very large models, or systems with limited RAM, you may
|
||||
experience slowdowns and other memory-related issues when loading
|
||||
checkpoint models.
|
||||
|
||||
To solve this, go to the Model Manager tab (the cube), select the
|
||||
checkpoint model that's giving you trouble, and press the "Convert"
|
||||
button in the upper right of your browser window. This will conver the
|
||||
checkpoint into a diffusers model, after which loading should be
|
||||
faster and less memory-intensive.
|
||||
|
||||
## Memory Leak (Linux)
|
||||
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional, Type
|
||||
|
||||
@@ -19,7 +17,6 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
@@ -34,7 +31,6 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
@@ -54,13 +50,6 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class CacheType(str, Enum):
|
||||
"""Cache type - one of vram or ram."""
|
||||
|
||||
RAM = "RAM"
|
||||
VRAM = "VRAM"
|
||||
|
||||
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
@@ -808,83 +797,3 @@ async def get_starter_models() -> list[StarterModel]:
|
||||
model.dependencies = missing_deps
|
||||
|
||||
return starter_models
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/model_cache",
|
||||
operation_id="get_cache_size",
|
||||
response_model=float,
|
||||
summary="Get maximum size of model manager RAM or VRAM cache.",
|
||||
)
|
||||
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
|
||||
"""Return the current RAM or VRAM cache size setting (in GB)."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
value = 0.0
|
||||
if cache_type == CacheType.RAM:
|
||||
value = cache.max_cache_size
|
||||
elif cache_type == CacheType.VRAM:
|
||||
value = cache.max_vram_cache_size
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/model_cache",
|
||||
operation_id="set_cache_size",
|
||||
response_model=float,
|
||||
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
|
||||
)
|
||||
async def set_cache_size(
|
||||
value: float = Query(description="The new value for the maximum cache size"),
|
||||
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
|
||||
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
|
||||
) -> float:
|
||||
"""Set the current RAM or VRAM cache size setting (in GB). ."""
|
||||
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
|
||||
app_config = get_config()
|
||||
# Record initial state.
|
||||
vram_old = app_config.vram
|
||||
ram_old = app_config.ram
|
||||
|
||||
# Prepare target state.
|
||||
vram_new = vram_old
|
||||
ram_new = ram_old
|
||||
if cache_type == CacheType.RAM:
|
||||
ram_new = value
|
||||
elif cache_type == CacheType.VRAM:
|
||||
vram_new = value
|
||||
else:
|
||||
raise ValueError(f"Unexpected {cache_type=}.")
|
||||
|
||||
config_path = app_config.config_file_path
|
||||
new_config_path = config_path.with_suffix(".yaml.new")
|
||||
|
||||
try:
|
||||
# Try to apply the target state.
|
||||
cache.max_vram_cache_size = vram_new
|
||||
cache.max_cache_size = ram_new
|
||||
app_config.ram = ram_new
|
||||
app_config.vram = vram_new
|
||||
if persist:
|
||||
app_config.write_file(new_config_path)
|
||||
shutil.move(new_config_path, config_path)
|
||||
except Exception as e:
|
||||
# If there was a failure, restore the initial state.
|
||||
cache.max_cache_size = ram_old
|
||||
cache.max_vram_cache_size = vram_old
|
||||
app_config.ram = ram_old
|
||||
app_config.vram = vram_old
|
||||
|
||||
raise RuntimeError("Failed to update cache size") from e
|
||||
return value
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/stats",
|
||||
operation_id="get_stats",
|
||||
response_model=Optional[CacheStats],
|
||||
summary="Get model manager RAM cache performance statistics.",
|
||||
)
|
||||
async def get_stats() -> Optional[CacheStats]:
|
||||
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
|
||||
|
||||
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import (
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import semver
|
||||
@@ -79,7 +80,7 @@ class UIConfigBase(BaseModel):
|
||||
version: str = Field(
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
|
||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
||||
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -229,16 +230,18 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
if title := model_class.UIConfig.title:
|
||||
schema["title"] = title
|
||||
if tags := model_class.UIConfig.tags:
|
||||
schema["tags"] = tags
|
||||
if category := model_class.UIConfig.category:
|
||||
schema["category"] = category
|
||||
if node_pack := model_class.UIConfig.node_pack:
|
||||
schema["node_pack"] = node_pack
|
||||
schema["classification"] = model_class.UIConfig.classification
|
||||
schema["version"] = model_class.UIConfig.version
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
if uiconfig.title is not None:
|
||||
schema["title"] = uiconfig.title
|
||||
if uiconfig.tags is not None:
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig.category is not None:
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig.node_pack is not None:
|
||||
schema["node_pack"] = uiconfig.node_pack
|
||||
schema["classification"] = uiconfig.classification
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
@@ -309,7 +312,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||
)
|
||||
|
||||
UIConfig: ClassVar[UIConfigBase]
|
||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(),
|
||||
@@ -438,25 +441,30 @@ def invocation(
|
||||
validate_fields(cls.model_fields, invocation_type)
|
||||
|
||||
# Add OpenAPI schema extras
|
||||
uiconfig: dict[str, Any] = {}
|
||||
uiconfig["title"] = title
|
||||
uiconfig["tags"] = tags
|
||||
uiconfig["category"] = category
|
||||
uiconfig["classification"] = classification
|
||||
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||
uiconfig_name = cls.__qualname__ + ".UIConfig"
|
||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
|
||||
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
|
||||
cls.UIConfig.title = title
|
||||
cls.UIConfig.tags = tags
|
||||
cls.UIConfig.category = category
|
||||
cls.UIConfig.classification = classification
|
||||
|
||||
# Grab the node pack's name from the module name, if it's a custom node
|
||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
||||
if is_custom_node:
|
||||
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
|
||||
else:
|
||||
cls.UIConfig.node_pack = None
|
||||
|
||||
if version is not None:
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
uiconfig["version"] = version
|
||||
cls.UIConfig.version = version
|
||||
else:
|
||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||
uiconfig["version"] = "1.0.0"
|
||||
|
||||
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||
cls.UIConfig.version = "1.0.0"
|
||||
|
||||
if use_cache is not None:
|
||||
cls.model_fields["use_cache"].default = use_cache
|
||||
|
||||
@@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
description=FieldDescriptions.mask,
|
||||
input=Input.Connection,
|
||||
ui_order=8,
|
||||
)
|
||||
|
||||
@@ -40,18 +40,14 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
FluxVAEModel = "FluxVAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@@ -129,17 +125,13 @@ class FieldDescriptions:
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
@@ -181,7 +173,7 @@ class FieldDescriptions:
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
denoise_mask = "A mask of the region to apply the denoising process to."
|
||||
mask = "The mask to use for the operation"
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
@@ -239,12 +231,6 @@ class ColorField(BaseModel):
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
clip_timestep_schedule,
|
||||
generate_img_ids,
|
||||
get_noise,
|
||||
get_schedule,
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_denoise",
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
|
||||
|
||||
# Prepare input noise.
|
||||
noise = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule(timesteps, self.denoising_start, self.denoising_end)
|
||||
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# If init_latents is provided, we are doing image-to-image.
|
||||
|
||||
if is_schnell:
|
||||
context.logger.warning(
|
||||
"Running image-to-image with a FLUX schnell model. This is not recommended. The results are likely "
|
||||
"to be poor. Consider using a FLUX dev model instead."
|
||||
)
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
|
||||
x = noise
|
||||
|
||||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||
# denoising steps.
|
||||
if len(timesteps) <= 1:
|
||||
return x
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, h, w = x.shape
|
||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=t5_embeddings,
|
||||
txt_ids=txt_ids,
|
||||
vec=clip_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
inpaint_extension=inpaint_extension,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
- Expands mask to the same shape as latents so that they line up after 'packing'
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. In 'unpacked' format. Used to determine the target shape,
|
||||
device, and dtype for the inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# Expand the inpaint mask to the same shape as `latents` so that when we 'pack' `mask` it lines up with
|
||||
# `latents`.
|
||||
return mask.expand_as(latents)
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
|
||||
context.util.flux_step_callback(state)
|
||||
|
||||
return step_callback
|
||||
@@ -1,92 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_encoder",
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a flux image."""
|
||||
|
||||
clip: CLIPField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField = InputField(
|
||||
title="T5Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info as clip_text_encoder,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
):
|
||||
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return pooled_prompt_embeds
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_decode",
|
||||
title="FLUX Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
return img_pil
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
image = self._vae_decode(vae_info=vae_info, latents=latents)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,67 +0,0 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_vae_encode",
|
||||
title="FLUX Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l", "flux"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode.",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Expose seed parameter at the invocation level.
|
||||
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
|
||||
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
|
||||
# should be used for VAE encode sampling.
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
@@ -126,7 +126,7 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
title="Tensor Mask to Image",
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.1.0",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Convert a mask tensor to an image."""
|
||||
@@ -135,11 +135,6 @@ class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.tensors.load(self.mask.tensor_name)
|
||||
|
||||
# Squeeze the channel dimension if it exists.
|
||||
if mask.dim() == 3:
|
||||
mask = mask.squeeze(0)
|
||||
|
||||
# Ensure that the mask is binary.
|
||||
if mask.dtype != torch.bool:
|
||||
mask = mask > 0.5
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -13,14 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@@ -67,15 +60,6 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
@@ -138,78 +122,6 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Flux base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
max_seq_len: Literal[256, 512] = OutputField(
|
||||
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
|
||||
title="Max Seq Length",
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_model_loader",
|
||||
title="Flux Main Model",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
ui_type=UIType.FluxMainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(transformer)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
|
||||
@@ -12,7 +12,6 @@ from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
@@ -415,17 +414,6 @@ class MaskOutput(BaseInvocationOutput):
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("flux_conditioning_output")
|
||||
class FluxConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
|
||||
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
@@ -88,8 +88,7 @@ class QueueItemEventBase(QueueEventBase):
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the queue item")
|
||||
destination: str | None = Field(default=None, description="The destination of the queue item")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
@@ -115,7 +114,6 @@ class InvocationStartedEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -150,7 +148,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -189,7 +186,6 @@ class InvocationCompleteEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -223,7 +219,6 @@ class InvocationErrorEvent(InvocationEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
@@ -262,7 +257,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
origin=queue_item.origin,
|
||||
destination=queue_item.destination,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
|
||||
@@ -103,7 +103,7 @@ class HFModelSource(StringLikeSource):
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f"::{self.subfolder.as_posix()}"
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
|
||||
@@ -783,9 +783,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||
if subfolder:
|
||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||
path_to_remove = top / subfolder # sdxl-turbo/vae/
|
||||
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
|
||||
path_to_add = Path(f"{top}_{subfolder_rename}")
|
||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||
path_to_add = Path(f"{top}_{subfolder}")
|
||||
else:
|
||||
path_to_remove = Path(".")
|
||||
path_to_add = Path(".")
|
||||
|
||||
@@ -77,7 +77,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
format: Optional[str] = Field(description="format of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
||||
@@ -77,14 +77,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
||||
|
||||
class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
origin: str | None = Field(default=None, description="The origin of this batch.")
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
@@ -203,14 +196,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||
priority: int = Field(default=0, description="The priority of this queue item")
|
||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||
origin: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
|
||||
)
|
||||
destination: str | None = Field(
|
||||
default=None,
|
||||
description="The origin of this queue item. This data is used by the frontend to determine how to handle results",
|
||||
)
|
||||
origin: str | None = Field(default=None, description="The origin of this queue item. ")
|
||||
session_id: str = Field(
|
||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||
)
|
||||
@@ -311,7 +297,6 @@ class BatchStatus(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
origin: str | None = Field(..., description="The origin of the batch")
|
||||
destination: str | None = Field(..., description="The destination of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
@@ -458,7 +443,6 @@ class SessionQueueValueToInsert(NamedTuple):
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
origin: str | None
|
||||
destination: str | None
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
@@ -480,7 +464,6 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
||||
priority, # priority
|
||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||
batch.origin, # origin
|
||||
batch.destination, # destination
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
||||
@@ -128,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
@@ -579,8 +579,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
origin
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
@@ -660,7 +659,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
SELECT status, count(*), origin
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -673,7 +672,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
@@ -683,7 +681,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
origin=origin,
|
||||
destination=destination,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
|
||||
@@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -557,24 +557,6 @@ class UtilInterface(InvocationContextInterface):
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
def flux_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
|
||||
"""
|
||||
The step callback emits a progress event with the current step, the total number of
|
||||
steps, a preview image, and some other internal metadata.
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
Args:
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
flux_step_callback(
|
||||
context_data=self._data,
|
||||
intermediate_state=intermediate_state,
|
||||
events=self._services.events,
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
@@ -10,11 +10,9 @@ class Migration15Callback:
|
||||
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN destination TEXT;")
|
||||
|
||||
|
||||
def build_migration_15() -> Migration:
|
||||
@@ -23,7 +21,6 @@ def build_migration_15() -> Migration:
|
||||
|
||||
This migration does the following:
|
||||
- Adds `origin` column to the session queue table.
|
||||
- Adds `destination` column to the session queue table.
|
||||
"""
|
||||
migration_15 = Migration(
|
||||
from_version=14,
|
||||
|
||||
@@ -1,407 +0,0 @@
|
||||
{
|
||||
"name": "FLUX Image to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple image-to-image workflow using a FLUX dev model. ",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "image2image, flux, image-to-image",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend using FLUX dev models for image-to-image workflows. The image-to-image performance with FLUX schnell models is poor.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "denoising_start"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"type": "flux_vae_encode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"image": {
|
||||
"name": "image",
|
||||
"label": "",
|
||||
"value": {
|
||||
"image_name": "8a5c62aa-9335-45d2-9c71-89af9fc1f8d4.png"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 732.7680166609682,
|
||||
"y": -24.37398171806909
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0.04
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1182.8836633018684,
|
||||
"y": -251.38882958913183
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "Model (dev variant recommended for Image-to-Image)"
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "fa23a584-b623-415d-832a-21b5098ff1a1",
|
||||
"hash": "blake3:17c19f0ef941c3b7609a9c94a659ca5364de0be364a91d4179f0e39ba17c3b70",
|
||||
"name": "clip-vit-large-patch14",
|
||||
"base": "any",
|
||||
"type": "clip_embed"
|
||||
}
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "74fc82ba-c0a8-479d-a890-2126f82da758",
|
||||
"hash": "blake3:ce21cb76364aa6e2421311cf4a4b5eb052a76c4f1cd207b50703d8978198a068",
|
||||
"name": "FLUX.1-schnell_ae",
|
||||
"base": "flux",
|
||||
"type": "vae"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 328.1809894659957,
|
||||
"y": -90.2241133566946
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat wearing a birthday hat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 745.8823365057267,
|
||||
"y": -299.60249175851914
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 725.834098928012,
|
||||
"y": 496.2710031089931
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bheight-ace0258f-67d7-4eee-a218-6fff27065214height",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "height",
|
||||
"targetHandle": "height"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912bwidth-ace0258f-67d7-4eee-a218-6fff27065214width",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "width",
|
||||
"targetHandle": "width"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2981a67c-480f-4237-9384-26b68dbf912blatents-ace0258f-67d7-4eee-a218-6fff27065214latents",
|
||||
"type": "default",
|
||||
"source": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-2981a67c-480f-4237-9384-26b68dbf912bvae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "2981a67c-480f-4237-9384-26b68dbf912b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ace0258f-67d7-4eee-a218-6fff27065214latents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-ace0258f-67d7-4eee-a218-6fff27065214seed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-ace0258f-67d7-4eee-a218-6fff27065214transformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-ace0258f-67d7-4eee-a218-6fff27065214positive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "ace0258f-67d7-4eee-a218-6fff27065214",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,326 +0,0 @@
|
||||
{
|
||||
"name": "FLUX Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "A simple text-to-image workflow using FLUX dev or schnell models.",
|
||||
"version": "1.0.4",
|
||||
"contact": "",
|
||||
"tags": "text2image, flux",
|
||||
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "t5_encoder_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "clip_embed_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"fieldName": "vae_model"
|
||||
},
|
||||
{
|
||||
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"fieldName": "num_steps"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"type": "flux_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"denoise_mask": {
|
||||
"name": "denoise_mask",
|
||||
"label": ""
|
||||
},
|
||||
"denoising_start": {
|
||||
"name": "denoising_start",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"name": "denoising_end",
|
||||
"label": "",
|
||||
"value": 1
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_text_conditioning": {
|
||||
"name": "positive_text_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"num_steps": {
|
||||
"name": "num_steps",
|
||||
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
|
||||
"value": 30
|
||||
},
|
||||
"guidance": {
|
||||
"name": "guidance",
|
||||
"label": "",
|
||||
"value": 4
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1186.1868226120378,
|
||||
"y": -214.9459927686657
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"type": "flux_vae_decode",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1575.5797431839133,
|
||||
"y": -209.00150975507415
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"type": "flux_model_loader",
|
||||
"version": "1.0.4",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_embed_model": {
|
||||
"name": "clip_embed_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 381.1882713063478,
|
||||
"y": -95.89663532854017
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"type": "flux_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"inputs": {
|
||||
"clip": {
|
||||
"name": "clip",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"t5_max_seq_len": {
|
||||
"name": "t5_max_seq_len",
|
||||
"label": "T5 Max Seq Len",
|
||||
"value": 256
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": "a cat"
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 778.4899149328337,
|
||||
"y": -100.36469216659502
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 800.9667463219505,
|
||||
"y": 285.8297267547506
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-4fe24f07-f906-4f55-ab2c-9beee56ef5bdtransformer",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-4fe24f07-f906-4f55-ab2c-9beee56ef5bdpositive_text_conditioning",
|
||||
"type": "default",
|
||||
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_text_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-4fe24f07-f906-4f55-ab2c-9beee56ef5bdseed",
|
||||
"type": "default",
|
||||
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
|
||||
"target": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-4fe24f07-f906-4f55-ab2c-9beee56ef5bdlatents-7e5172eb-48c1-44db-a770-8fd83e1435d1latents",
|
||||
"type": "default",
|
||||
"source": "4fe24f07-f906-4f55-ab2c-9beee56ef5bd",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-7e5172eb-48c1-44db-a770-8fd83e1435d1vae",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "7e5172eb-48c1-44db-a770-8fd83e1435d1",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "max_seq_len",
|
||||
"targetHandle": "t5_max_seq_len"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
|
||||
"type": "default",
|
||||
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
|
||||
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -38,25 +38,6 @@ SD1_5_LATENT_RGB_FACTORS = [
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
FLUX_LATENT_RGB_FACTORS = [
|
||||
[-0.0412, 0.0149, 0.0521],
|
||||
[0.0056, 0.0291, 0.0768],
|
||||
[0.0342, -0.0681, -0.0427],
|
||||
[-0.0258, 0.0092, 0.0463],
|
||||
[0.0863, 0.0784, 0.0547],
|
||||
[-0.0017, 0.0402, 0.0158],
|
||||
[0.0501, 0.1058, 0.1152],
|
||||
[-0.0209, -0.0218, -0.0329],
|
||||
[-0.0314, 0.0083, 0.0896],
|
||||
[0.0851, 0.0665, -0.0472],
|
||||
[-0.0534, 0.0238, -0.0024],
|
||||
[0.0452, -0.0026, 0.0048],
|
||||
[0.0892, 0.0831, 0.0881],
|
||||
[-0.1117, -0.0304, -0.0789],
|
||||
[0.0027, -0.0479, -0.0043],
|
||||
[-0.1146, -0.0827, -0.0598],
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
@@ -113,32 +94,3 @@ def stable_diffusion_step_callback(
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
||||
|
||||
def flux_step_callback(
|
||||
context_data: "InvocationContextData",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
events: "EventServiceBase",
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
sample = intermediate_state.latents
|
||||
latent_rgb_factors = torch.tensor(FLUX_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
latent_image_perm = sample.permute(1, 2, 0).to(dtype=sample.dtype, device=sample.device)
|
||||
latent_image = latent_image_perm @ latent_rgb_factors
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF) # change scale from -1..1 to 0..1 # to 0..255
|
||||
).to(device="cpu", dtype=torch.uint8)
|
||||
image = Image.fromarray(latents_ubyte.cpu().numpy())
|
||||
(width, height) = image.size
|
||||
width *= 8
|
||||
height *= 8
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
):
|
||||
step = 0
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
preview_img = img - t_curr * pred
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
if inpaint_extension is not None:
|
||||
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step,
|
||||
order=1,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t_curr),
|
||||
latents=preview_img,
|
||||
),
|
||||
)
|
||||
step += 1
|
||||
|
||||
return img
|
||||
@@ -1,35 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class InpaintExtension:
|
||||
"""A class for managing inpainting with FLUX."""
|
||||
|
||||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||
"""Initialize InpaintExtension.
|
||||
|
||||
Args:
|
||||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0). In 'packed' format.
|
||||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||
inpainted region with the background. In 'packed' format.
|
||||
noise (torch.Tensor): The noise tensor used to noise the init_latents. In 'packed' format.
|
||||
"""
|
||||
assert init_latents.shape == inpaint_mask.shape == noise.shape
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, timestep: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||
trajectory.
|
||||
|
||||
This function should be called after each denoising step.
|
||||
"""
|
||||
# Noise the init latents for the current timestep.
|
||||
noised_init_latents = self._noise * timestep + (1.0 - timestep) * self._init_latents
|
||||
|
||||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * self._inpaint_mask + noised_init_latents * (1.0 - self._inpaint_mask)
|
||||
@@ -1,32 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
@@ -1,117 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
@@ -1,324 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: list[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, h_: Tensor) -> Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
# Unfortunately, torch.randn_like(...) does not accept a generator argument at the time of writing, so we
|
||||
# have to use torch.randn(...) instead.
|
||||
return mean + std * torch.randn(size=mean.size(), generator=generator, dtype=mean.dtype, device=mean.device)
|
||||
else:
|
||||
return mean
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, params: AutoEncoderParams):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
out_ch=params.out_ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.reg = DiagonalGaussian()
|
||||
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor, sample: bool = True, generator: torch.Generator | None = None) -> Tensor:
|
||||
"""Run VAE encoding on input tensor x.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input image tensor. Shape: (batch_size, in_channels, height, width).
|
||||
sample (bool, optional): If True, sample from the encoded distribution, else, return the distribution mean.
|
||||
Defaults to True.
|
||||
generator (torch.Generator | None, optional): Optional random number generator for reproducibility.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded latent tensor. Shape: (batch_size, z_channels, latent_height, latent_width).
|
||||
"""
|
||||
|
||||
z = self.reg(self.encoder(x), sample=sample, generator=generator)
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.decode(self.encode(x))
|
||||
@@ -1,33 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from torch import Tensor, nn
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.is_clip = is_clip
|
||||
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
||||
self.tokenizer = tokenizer
|
||||
self.hf_module = encoder
|
||||
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
||||
|
||||
def forward(self, text: list[str]) -> Tensor:
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
outputs = self.hf_module(
|
||||
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
||||
attention_mask=None,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return outputs[self.output_key]
|
||||
@@ -1,253 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.math import attention, rope
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
@@ -1,135 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# estimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def _find_last_index_ge_val(timesteps: list[float], val: float, eps: float = 1e-6) -> int:
|
||||
"""Find the last index in timesteps that is >= val.
|
||||
|
||||
We use epsilon-close equality to avoid potential floating point errors.
|
||||
"""
|
||||
idx = len(list(filter(lambda t: t >= (val - eps), timesteps))) - 1
|
||||
assert idx >= 0
|
||||
return idx
|
||||
|
||||
|
||||
def clip_timestep_schedule(timesteps: list[float], denoising_start: float, denoising_end: float) -> list[float]:
|
||||
"""Clip the timestep schedule to the denoising range.
|
||||
|
||||
Args:
|
||||
timesteps (list[float]): The original timestep schedule: [1.0, ..., 0.0].
|
||||
denoising_start (float): A value in [0, 1] specifying the start of the denoising process. E.g. a value of 0.2
|
||||
would mean that the denoising process start at the last timestep in the schedule >= 0.8.
|
||||
denoising_end (float): A value in [0, 1] specifying the end of the denoising process. E.g. a value of 0.8 would
|
||||
mean that the denoising process end at the last timestep in the schedule >= 0.2.
|
||||
|
||||
Returns:
|
||||
list[float]: The clipped timestep schedule.
|
||||
"""
|
||||
assert 0.0 <= denoising_start <= 1.0
|
||||
assert 0.0 <= denoising_end <= 1.0
|
||||
assert denoising_start <= denoising_end
|
||||
|
||||
t_start_val = 1.0 - denoising_start
|
||||
t_end_val = 1.0 - denoising_end
|
||||
|
||||
t_start_idx = _find_last_index_ge_val(timesteps, t_start_val)
|
||||
t_end_idx = _find_last_index_ge_val(timesteps, t_end_val)
|
||||
|
||||
clipped_timesteps = timesteps[t_start_idx : t_end_idx + 1]
|
||||
|
||||
return clipped_timesteps
|
||||
|
||||
|
||||
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""Unpack flat array of patch embeddings to latent image."""
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
|
||||
|
||||
def pack(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack latent image to flattented array of patch embeddings."""
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
|
||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""Generate tensor of image position ids.
|
||||
|
||||
Args:
|
||||
h (int): Height of image in latent space.
|
||||
w (int): Width of image in latent space.
|
||||
batch_size (int): Batch size.
|
||||
device (torch.device): Device.
|
||||
dtype (torch.dtype): dtype.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image position ids.
|
||||
"""
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
@@ -1,71 +0,0 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Literal
|
||||
|
||||
from invokeai.backend.flux.model import FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
ckpt_path: str | None
|
||||
ae_path: str | None
|
||||
repo_id: str | None
|
||||
repo_flow: str | None
|
||||
repo_ae: str | None
|
||||
|
||||
|
||||
max_seq_lengths: Dict[str, Literal[256, 512]] = {
|
||||
"flux-dev": 512,
|
||||
"flux-schnell": 256,
|
||||
}
|
||||
|
||||
|
||||
ae_params = {
|
||||
"flux": AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
params = {
|
||||
"flux-dev": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
"flux-schnell": FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
}
|
||||
@@ -52,7 +52,6 @@ class BaseModelType(str, Enum):
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
@@ -67,9 +66,7 @@ class ModelType(str, Enum):
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
CLIPEmbed = "clip_embed"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
T5Encoder = "t5_encoder"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
|
||||
|
||||
@@ -77,7 +74,6 @@ class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
Transformer = "transformer"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
@@ -108,9 +104,6 @@ class ModelFormat(str, Enum):
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
T5Encoder = "t5_encoder"
|
||||
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
@@ -193,9 +186,7 @@ class ModelConfigBase(BaseModel):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
|
||||
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
|
||||
)
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
@@ -214,26 +205,6 @@ class LoRAConfigBase(ModelConfigBase):
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
|
||||
|
||||
class T5EncoderConfigBase(ModelConfigBase):
|
||||
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
|
||||
|
||||
|
||||
class T5EncoderConfig(T5EncoderConfigBase):
|
||||
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
|
||||
|
||||
|
||||
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
@@ -258,6 +229,7 @@ class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@@ -296,6 +268,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@@ -344,21 +317,6 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.format = ModelFormat.BnbQuantizednf4b
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
@@ -392,17 +350,6 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for Clip Embeddings."""
|
||||
|
||||
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
@@ -461,15 +408,12 @@ AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
||||
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
@@ -477,7 +421,6 @@ AnyModelConfig = Annotated[
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -66,14 +66,12 @@ class ModelLoader(ModelLoaderBase):
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
config.path = str(self._get_model_path(config))
|
||||
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
@@ -85,7 +83,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=stats_name,
|
||||
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||
)
|
||||
|
||||
def get_size_fs(
|
||||
|
||||
@@ -128,24 +128,7 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -210,6 +193,15 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
""" """
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import math
|
||||
@@ -24,74 +40,53 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
"""Implementation of ModelCacheBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._precision: torch.dtype = precision
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
@@ -133,16 +128,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
@@ -160,6 +145,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def exists(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> bool:
|
||||
"""Return true if the model identified by key and submodel_type is in the cache."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
return key in self._cached_models
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
@@ -209,7 +203,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
@@ -237,13 +231,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
@@ -254,7 +245,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
@@ -312,7 +303,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
@@ -335,14 +326,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % (self.cache_size() / GIG)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
@@ -362,20 +353,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
@@ -392,7 +380,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Class for Flux model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
MainBnbQuantized4bCheckpointConfig,
|
||||
MainCheckpointConfig,
|
||||
T5EncoderBnbQuantizedLlmInt8bConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.util.model_util import convert_bundle_to_flux_transformer_checkpoint
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
try:
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
bnb_available = True
|
||||
except ImportError:
|
||||
bnb_available = False
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class FluxVAELoader(ModelLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, VAECheckpointConfig):
|
||||
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
|
||||
class ClipCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CLIPEmbedDiffusersConfig):
|
||||
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer:
|
||||
return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
|
||||
case SubModelType.TextEncoder:
|
||||
return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
|
||||
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
|
||||
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
|
||||
if not bnb_available:
|
||||
raise ImportError(
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
)
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
te2_model_path = Path(config.path) / "text_encoder_2"
|
||||
model_config = AutoConfig.from_pretrained(te2_model_path)
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForTextEncoding.from_config(model_config)
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
|
||||
|
||||
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
|
||||
state_dict = load_file(state_dict_path)
|
||||
self._load_state_dict_into_t5(model, state_dict)
|
||||
|
||||
return model
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
|
||||
# There is a shared reference to a single weight tensor in the model.
|
||||
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
|
||||
# be present in the state_dict.
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
assert len(unexpected_keys) == 0
|
||||
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
|
||||
# Assert that the layers we expect to be shared are actually shared.
|
||||
assert model.encoder.embed_tokens.weight is model.shared.weight
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||
class T5EncoderCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderConfig):
|
||||
raise ValueError("Only T5EncoderConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
|
||||
|
||||
raise ValueError(
|
||||
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class FluxCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
model = Flux(params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
|
||||
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
|
||||
if not bnb_available:
|
||||
raise ImportError(
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with SilenceWarnings():
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(params[config.config_path])
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
|
||||
sd = load_file(model_path)
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
@@ -78,12 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||
if module in [
|
||||
"diffusers",
|
||||
"transformers",
|
||||
"invokeai.backend.quantization.fast_quantized_transformers_model",
|
||||
"invokeai.backend.quantization.fast_quantized_diffusion_model",
|
||||
]:
|
||||
if module in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
|
||||
@@ -36,18 +36,8 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
@@ -50,17 +50,6 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
),
|
||||
):
|
||||
return model.calc_size()
|
||||
elif isinstance(
|
||||
model,
|
||||
(
|
||||
T5TokenizerFast,
|
||||
T5Tokenizer,
|
||||
),
|
||||
):
|
||||
# HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
|
||||
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
|
||||
# point.
|
||||
return len(model)
|
||||
else:
|
||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||
# supported model types.
|
||||
|
||||
@@ -95,7 +95,6 @@ class ModelProbe(object):
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"FluxPipeline": ModelType.Main,
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
@@ -107,9 +106,6 @@ class ModelProbe(object):
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
"CLIPModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -165,7 +161,7 @@ class ModelProbe(object):
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
@@ -180,10 +176,10 @@ class ModelProbe(object):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
|
||||
ModelFormat.Checkpoint,
|
||||
ModelFormat.BnbQuantizednf4b,
|
||||
]:
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
ckpt_config_path = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
@@ -226,19 +222,7 @@ class ModelProbe(object):
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(
|
||||
(
|
||||
"cond_stage_model.",
|
||||
"first_stage_model.",
|
||||
"model.diffusion_model.",
|
||||
# FLUX models in the official BFL format contain keys with the "double_blocks." prefix.
|
||||
"double_blocks.",
|
||||
# Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model".
|
||||
# This prefix is typically used to distinguish between multiple models bundled in a single file.
|
||||
"model.diffusion_model.double_blocks.",
|
||||
)
|
||||
):
|
||||
# Keys starting with double_blocks are associated with Flux models
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
|
||||
return ModelType.Main
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
@@ -296,16 +280,9 @@ class ModelProbe(object):
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
config_path = None
|
||||
for p in [
|
||||
folder_path / "model_index.json", # pipeline
|
||||
folder_path / "config.json", # most diffusers
|
||||
folder_path / "text_encoder_2" / "config.json", # T5 text encoder
|
||||
folder_path / "text_encoder" / "config.json", # T5 CLIP
|
||||
]:
|
||||
if p.exists():
|
||||
config_path = p
|
||||
break
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
@@ -344,30 +321,10 @@ class ModelProbe(object):
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type is ModelType.Main:
|
||||
if base_type == BaseModelType.Flux:
|
||||
# TODO: Decide between dev/schnell
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
if (
|
||||
"guidance_in.out_layer.weight" in state_dict
|
||||
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
|
||||
):
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-dev"
|
||||
else:
|
||||
# For flux, this is a key in invokeai.backend.flux.util.params
|
||||
# Due to model type and format being the discriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"controlnet/cldm_v15.yaml"
|
||||
@@ -376,13 +333,7 @@ class ModelProbe(object):
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
# For flux, this is a key in invokeai.backend.flux.util.ae_params
|
||||
# Due to model type and format being the descriminator for model configs this
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
"flux"
|
||||
if base_type is BaseModelType.Flux
|
||||
else "stable-diffusion/v1-inference.yaml"
|
||||
"stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "stable-diffusion/sd_xl_base.yaml"
|
||||
if base_type is BaseModelType.StableDiffusionXL
|
||||
@@ -465,18 +416,11 @@ class CheckpointProbeBase(ProbeBase):
|
||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
if (
|
||||
"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
|
||||
or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict
|
||||
):
|
||||
return ModelFormat.BnbQuantizednf4b
|
||||
return ModelFormat("checkpoint")
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
@@ -496,11 +440,6 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
if (
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
|
||||
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
@@ -543,7 +482,6 @@ class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
|
||||
]:
|
||||
if re.search(regexp, self.model_path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
@@ -775,30 +713,6 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
|
||||
|
||||
class T5EncoderFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
path = self.model_path / "text_encoder_2"
|
||||
if (path / "model.safetensors.index.json").exists():
|
||||
return ModelFormat.T5Encoder
|
||||
files = list(path.glob("*.safetensors"))
|
||||
if len(files) == 0:
|
||||
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found")
|
||||
|
||||
# shortcut: look for the quantization in the name
|
||||
if any(x for x in files if "llm_int8" in x.as_posix()):
|
||||
return ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
# more reliable path: probe contents for a 'SCB' key
|
||||
ckpt = read_checkpoint_meta(files[0], scan=True)
|
||||
if any("SCB" in x for x in ckpt.keys()):
|
||||
return ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format")
|
||||
|
||||
|
||||
class ONNXFolderProbe(PipelineFolderProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# Due to the way the installer is set up, the configuration file for safetensors
|
||||
@@ -891,11 +805,6 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class CLIPEmbedFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
@@ -926,10 +835,8 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
|
||||
class StarterModelWithoutDependencies(BaseModel):
|
||||
@@ -11,7 +11,6 @@ class StarterModelWithoutDependencies(BaseModel):
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
format: Optional[ModelFormat] = None
|
||||
is_installed: bool = False
|
||||
|
||||
|
||||
@@ -52,76 +51,10 @@ cyberrealistic_negative = StarterModel(
|
||||
type=ModelType.TextualInversion,
|
||||
)
|
||||
|
||||
t5_base_encoder = StarterModel(
|
||||
name="t5_base_encoder",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/t5-v1_1-xxl::bfloat16",
|
||||
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
|
||||
type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
t5_8b_quantized_encoder = StarterModel(
|
||||
name="t5_bnb_int8_quantized_encoder",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
|
||||
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
|
||||
type=ModelType.T5Encoder,
|
||||
format=ModelFormat.BnbQuantizedLlmInt8b,
|
||||
)
|
||||
|
||||
clip_l_encoder = StarterModel(
|
||||
name="clip-vit-large-patch14",
|
||||
base=BaseModelType.Any,
|
||||
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
|
||||
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
|
||||
type=ModelType.CLIPEmbed,
|
||||
)
|
||||
|
||||
flux_vae = StarterModel(
|
||||
name="FLUX.1-schnell_ae",
|
||||
base=BaseModelType.Flux,
|
||||
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
|
||||
description="FLUX VAE compatible with both schnell and dev variants.",
|
||||
type=ModelType.VAE,
|
||||
)
|
||||
|
||||
|
||||
# List of starter models, displayed on the frontend.
|
||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||
STARTER_MODELS: list[StarterModel] = [
|
||||
# region: Main
|
||||
StarterModel(
|
||||
name="FLUX Schnell (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
|
||||
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX Dev (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
|
||||
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX Schnell",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
|
||||
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX Dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
|
||||
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
),
|
||||
StarterModel(
|
||||
name="CyberRealistic v4.1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -192,7 +125,6 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
# endregion
|
||||
# region VAE
|
||||
sdxl_fp16_vae_fix,
|
||||
flux_vae,
|
||||
# endregion
|
||||
# region LoRA
|
||||
StarterModel(
|
||||
@@ -518,11 +450,6 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
type=ModelType.SpandrelImageToImage,
|
||||
),
|
||||
# endregion
|
||||
# region TextEncoders
|
||||
t5_base_encoder,
|
||||
t5_8b_quantized_encoder,
|
||||
clip_l_encoder,
|
||||
# endregion
|
||||
]
|
||||
|
||||
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
|
||||
|
||||
@@ -133,29 +133,3 @@ def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[in
|
||||
break
|
||||
|
||||
return lora_token_vector_length
|
||||
|
||||
|
||||
def convert_bundle_to_flux_transformer_checkpoint(
|
||||
transformer_state_dict: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
original_state_dict: dict[str, torch.Tensor] = {}
|
||||
keys_to_remove: list[str] = []
|
||||
|
||||
for k, v in transformer_state_dict.items():
|
||||
if not k.startswith("model.diffusion_model"):
|
||||
keys_to_remove.append(k) # This can be removed in the future if we only want to delete transformer keys
|
||||
continue
|
||||
if k.endswith("scale"):
|
||||
# Scale math must be done at bfloat16 due to our current flux model
|
||||
# support limitations at inference time
|
||||
v = v.to(dtype=torch.bfloat16)
|
||||
new_key = k.replace("model.diffusion_model.", "")
|
||||
original_state_dict[new_key] = v
|
||||
keys_to_remove.append(k)
|
||||
|
||||
# Remove processed keys from the original dictionary, leaving others in case
|
||||
# other model state dicts need to be pulled
|
||||
for k in keys_to_remove:
|
||||
del transformer_state_dict[k]
|
||||
|
||||
return original_state_dict
|
||||
|
||||
@@ -54,7 +54,6 @@ def filter_files(
|
||||
"lora_weights.safetensors",
|
||||
"weights.pb",
|
||||
"onnx_data",
|
||||
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
|
||||
)
|
||||
):
|
||||
paths.append(file)
|
||||
@@ -63,13 +62,13 @@ def filter_files(
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
if subfolder:
|
||||
subfolder = root / subfolder
|
||||
paths = [x for x in paths if Path(subfolder) in x.parents]
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
@@ -98,9 +97,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
# Note: '.model' was added to support:
|
||||
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
|
||||
elif path.suffix in [".json", ".txt", ".model"]:
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
result.add(path)
|
||||
|
||||
elif variant in [
|
||||
@@ -143,23 +140,6 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
continue
|
||||
|
||||
for candidate_list in subfolder_weights.values():
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
if not at_least_one_fp16:
|
||||
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
|
||||
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
|
||||
# we'll simply keep all the candidates. An example of a model that hits this case is
|
||||
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
|
||||
for candidate in candidate_list:
|
||||
result.add(candidate.path)
|
||||
|
||||
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
|
||||
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeInt8Params(bnb.nn.Int8Params):
|
||||
"""We override cuda() to avoid re-quantizing the weights in the following cases:
|
||||
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
|
||||
- We are moving the model back-and-forth between the cpu and gpu.
|
||||
"""
|
||||
|
||||
def cuda(self, device):
|
||||
if self.has_fp16_weights:
|
||||
return super().cuda(device)
|
||||
elif self.CB is not None and self.SCB is not None:
|
||||
self.data = self.data.cuda()
|
||||
self.CB = self.data
|
||||
self.SCB = self.SCB.cuda()
|
||||
else:
|
||||
# we store the 8-bit rows-major weight
|
||||
# we convert this weight to the turning/ampere weight during the first inference pass
|
||||
B = self.data.contiguous().half().cuda(device)
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
self.data = CB
|
||||
self.CB = CB
|
||||
self.SCB = SCB
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
|
||||
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
|
||||
scb = state_dict.pop(prefix + "SCB", None)
|
||||
|
||||
# Currently, we only support weight_format=0.
|
||||
weight_format = state_dict.pop(prefix + "weight_format", None)
|
||||
assert weight_format == 0
|
||||
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert len(state_dict) == 0
|
||||
|
||||
if scb is not None:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = InvokeInt8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
# Note: After quantization, CB is the same as weight.
|
||||
CB=weight,
|
||||
SCB=scb,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = InvokeInt8Params(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
has_fp16_weights=False,
|
||||
CB=None,
|
||||
SCB=None,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
# Reset the state. The persisted fields are based on the initialization behaviour in
|
||||
# `bnb.nn.Linear8bitLt.__init__()`.
|
||||
new_state = bnb.MatmulLtState()
|
||||
new_state.threshold = self.state.threshold
|
||||
new_state.has_fp16_weights = False
|
||||
new_state.use_pool = self.state.use_pool
|
||||
self.state = new_state
|
||||
|
||||
|
||||
def _convert_linear_layers_to_llm_8bit(
|
||||
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
|
||||
) -> None:
|
||||
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinear8bitLt(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=outlier_threshold,
|
||||
)
|
||||
replacement.weight.data = child.weight.data
|
||||
if has_bias:
|
||||
replacement.bias.data = child.bias.data
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
|
||||
)
|
||||
|
||||
|
||||
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
|
||||
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
|
||||
_convert_linear_layers_to_llm_8bit(
|
||||
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
|
||||
)
|
||||
|
||||
return model
|
||||
@@ -1,156 +0,0 @@
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
|
||||
# The utils in this file are partially inspired by:
|
||||
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
|
||||
|
||||
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
|
||||
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
|
||||
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
|
||||
|
||||
|
||||
class InvokeLinearNF4(bnb.nn.LinearNF4):
|
||||
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
|
||||
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
|
||||
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
|
||||
"""
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
|
||||
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
|
||||
"""
|
||||
weight = state_dict.pop(prefix + "weight")
|
||||
bias = state_dict.pop(prefix + "bias", None)
|
||||
# We expect the remaining keys to be quant_state keys.
|
||||
quant_state_sd = state_dict
|
||||
|
||||
# During serialization, the quant_state is stored as subkeys of "weight." (See
|
||||
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
|
||||
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
|
||||
# rather than raising an exception to correctly implement this API.
|
||||
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
|
||||
|
||||
if len(quant_state_sd) > 0:
|
||||
# We are loading a pre-quantized state dict.
|
||||
self.weight = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight, quantized_stats=quant_state_sd, device=weight.device
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
|
||||
else:
|
||||
# We are loading a non-quantized state dict.
|
||||
|
||||
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
|
||||
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
|
||||
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
|
||||
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
|
||||
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
|
||||
self.weight = bnb.nn.Params4bit(
|
||||
data=weight,
|
||||
requires_grad=self.weight.requires_grad,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
)
|
||||
self.bias = bias if bias is None else torch.nn.Parameter(bias)
|
||||
|
||||
|
||||
def _replace_param(
|
||||
param: torch.nn.Parameter | bnb.nn.Params4bit,
|
||||
data: torch.Tensor,
|
||||
) -> torch.nn.Parameter:
|
||||
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
|
||||
the "meta" device.
|
||||
|
||||
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
|
||||
"""
|
||||
if param.device.type == "meta":
|
||||
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
|
||||
# re-create the param instead of overwriting the data.
|
||||
if isinstance(param, bnb.nn.Params4bit):
|
||||
return bnb.nn.Params4bit(
|
||||
data,
|
||||
requires_grad=data.requires_grad,
|
||||
quant_state=param.quant_state,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
)
|
||||
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
||||
|
||||
param.data = data
|
||||
return param
|
||||
|
||||
|
||||
def _convert_linear_layers_to_nf4(
|
||||
module: torch.nn.Module,
|
||||
ignore_modules: set[str],
|
||||
compute_dtype: torch.dtype,
|
||||
compress_statistics: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""Convert all linear layers in the model to NF4 quantized linear layers.
|
||||
|
||||
Args:
|
||||
module: All linear layers in this module will be converted.
|
||||
ignore_modules: A set of module prefixes to ignore when converting linear layers.
|
||||
compute_dtype: The dtype to use for computation in the quantized linear layers.
|
||||
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
|
||||
constants from the first quantization are quantized again.
|
||||
prefix: The prefix of the current module in the model. Used to call this function recursively.
|
||||
"""
|
||||
for name, child in module.named_children():
|
||||
fullname = f"{prefix}.{name}" if prefix else name
|
||||
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
|
||||
has_bias = child.bias is not None
|
||||
replacement = InvokeLinearNF4(
|
||||
child.in_features,
|
||||
child.out_features,
|
||||
bias=has_bias,
|
||||
compute_dtype=compute_dtype,
|
||||
compress_statistics=compress_statistics,
|
||||
)
|
||||
if has_bias:
|
||||
replacement.bias = _replace_param(replacement.bias, child.bias.data)
|
||||
replacement.weight = _replace_param(replacement.weight, child.weight.data)
|
||||
replacement.requires_grad_(False)
|
||||
module.__setattr__(name, replacement)
|
||||
else:
|
||||
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
|
||||
|
||||
|
||||
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
|
||||
"""Apply bitsandbytes nf4 quantization to the model.
|
||||
|
||||
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
# Initialize the model from a config on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = ModelClass.from_config(...)
|
||||
|
||||
# Add NF4 quantization linear layers to the model - still on the meta device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
|
||||
|
||||
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
|
||||
# place.
|
||||
model.to("cuda")
|
||||
```
|
||||
"""
|
||||
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
|
||||
|
||||
return model
|
||||
@@ -1,79 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
|
||||
|
||||
|
||||
def main():
|
||||
"""A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.
|
||||
|
||||
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
|
||||
etc.) are hardcoded and would need to be modified for other use cases.
|
||||
"""
|
||||
# Load the FLUX transformer model onto the meta device.
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||
)
|
||||
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
p = params["flux-schnell"]
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(p)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
|
||||
if model_int8_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_int8_path)
|
||||
model.load_state_dict(sd, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_int8_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
|
||||
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,96 +0,0 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_time(name: str):
|
||||
"""Helper context manager to log the time taken by a block of code."""
|
||||
start = time.time()
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
end = time.time()
|
||||
print(f"'{name}' took {end - start:.4f} secs")
|
||||
|
||||
|
||||
def main():
|
||||
"""A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.
|
||||
|
||||
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
|
||||
etc.) are hardcoded and would need to be modified for other use cases.
|
||||
"""
|
||||
model_path = Path(
|
||||
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
|
||||
)
|
||||
|
||||
# inference_dtype = torch.bfloat16
|
||||
with log_time("Intialize FLUX transformer on meta device"):
|
||||
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
|
||||
p = params["flux-schnell"]
|
||||
|
||||
# Initialize the model on the "meta" device.
|
||||
with accelerate.init_empty_weights():
|
||||
model = Flux(p)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
|
||||
if model_nf4_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_nf4_path)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_nf4(
|
||||
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
state_dict = load_file(model_path)
|
||||
# TODO(ryand): Cast the state_dict to the appropriate dtype?
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(model.state_dict(), model_nf4_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
|
||||
|
||||
assert isinstance(model, Flux)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,92 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
|
||||
|
||||
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
|
||||
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
|
||||
|
||||
|
||||
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
|
||||
# There is a shared reference to a single weight tensor in the model.
|
||||
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
|
||||
# be present in the state_dict.
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
assert len(unexpected_keys) == 0
|
||||
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
|
||||
# Assert that the layers we expect to be shared are actually shared.
|
||||
assert model.encoder.embed_tokens.weight is model.shared.weight
|
||||
|
||||
|
||||
def main():
|
||||
"""A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.
|
||||
|
||||
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
|
||||
etc.) are hardcoded and would need to be modified for other use cases.
|
||||
"""
|
||||
model_path = Path("/data/misc/text_encoder_2")
|
||||
|
||||
with log_time("Intialize T5 on meta device"):
|
||||
model_config = AutoConfig.from_pretrained(model_path)
|
||||
with accelerate.init_empty_weights():
|
||||
model = AutoModelForTextEncoding.from_config(model_config)
|
||||
|
||||
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
|
||||
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
|
||||
modules_to_not_convert: set[str] = set()
|
||||
|
||||
model_int8_path = model_path / "bnb_llm_int8.safetensors"
|
||||
if model_int8_path.exists():
|
||||
# The quantized model already exists, load it and return it.
|
||||
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
|
||||
|
||||
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
sd = load_file(model_int8_path)
|
||||
load_state_dict_into_t5(model, sd)
|
||||
|
||||
with log_time("Move model to cuda"):
|
||||
model = model.to("cuda")
|
||||
|
||||
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
|
||||
|
||||
else:
|
||||
# The quantized model does not exist, quantize the model and save it.
|
||||
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
|
||||
|
||||
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
|
||||
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
|
||||
|
||||
with log_time("Load state dict into model"):
|
||||
# Load sharded state dict.
|
||||
files = list(model_path.glob("*.safetensors"))
|
||||
state_dict = {}
|
||||
for file in files:
|
||||
sd = load_file(file)
|
||||
state_dict.update(sd)
|
||||
load_state_dict_into_t5(model, state_dict)
|
||||
|
||||
with log_time("Move model to cuda and quantize"):
|
||||
model = model.to("cuda")
|
||||
|
||||
with log_time("Save quantized model"):
|
||||
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
state_dict = model.state_dict()
|
||||
state_dict.pop("encoder.embed_tokens.weight")
|
||||
save_file(state_dict, model_int8_path)
|
||||
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
|
||||
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
|
||||
# save_model(model, model_int8_path)
|
||||
|
||||
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
|
||||
|
||||
assert isinstance(model, T5EncoderModel)
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,6 +25,11 @@ class BasicConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
@@ -38,22 +43,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FLUXConditioningInfo:
|
||||
clip_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
|
||||
@@ -3,9 +3,10 @@ Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import Chdir, directory_size
|
||||
from invokeai.backend.util.util import GIG, Chdir, directory_size
|
||||
|
||||
__all__ = [
|
||||
"GIG",
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
|
||||
@@ -7,6 +7,9 @@ from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
|
||||
@@ -64,7 +64,6 @@
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
"async-mutex": "^0.5.0",
|
||||
"chakra-react-select": "^4.9.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"compare-versions": "^6.1.1",
|
||||
"dateformat": "^5.0.3",
|
||||
"fracturedjsonjs": "^4.0.2",
|
||||
@@ -93,6 +92,7 @@
|
||||
"react-icons": "^5.2.1",
|
||||
"react-redux": "9.1.2",
|
||||
"react-resizable-panels": "^2.0.23",
|
||||
"react-select": "5.8.0",
|
||||
"react-use": "^17.5.1",
|
||||
"react-virtuoso": "^4.9.0",
|
||||
"reactflow": "^11.11.4",
|
||||
@@ -136,7 +136,6 @@
|
||||
"@vitest/coverage-v8": "^1.5.0",
|
||||
"@vitest/ui": "^1.5.0",
|
||||
"concurrently": "^8.2.2",
|
||||
"csstype": "^3.1.3",
|
||||
"dpdm": "^3.14.0",
|
||||
"eslint": "^8.57.0",
|
||||
"eslint-plugin-i18next": "^6.0.9",
|
||||
|
||||
332
invokeai/frontend/web/pnpm-lock.yaml
generated
332
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -41,9 +41,6 @@ dependencies:
|
||||
chakra-react-select:
|
||||
specifier: ^4.9.1
|
||||
version: 4.9.1(@chakra-ui/form-control@2.2.0)(@chakra-ui/icon@3.2.0)(@chakra-ui/layout@2.3.1)(@chakra-ui/media-query@3.3.0)(@chakra-ui/menu@2.2.1)(@chakra-ui/spinner@2.1.0)(@chakra-ui/system@2.6.2)(@emotion/react@11.13.3)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
cmdk:
|
||||
specifier: ^1.0.0
|
||||
version: 1.0.0(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
compare-versions:
|
||||
specifier: ^6.1.1
|
||||
version: 6.1.1
|
||||
@@ -128,6 +125,9 @@ dependencies:
|
||||
react-resizable-panels:
|
||||
specifier: ^2.0.23
|
||||
version: 2.0.23(react-dom@18.3.1)(react@18.3.1)
|
||||
react-select:
|
||||
specifier: 5.8.0
|
||||
version: 5.8.0(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
react-use:
|
||||
specifier: ^17.5.1
|
||||
version: 17.5.1(react-dom@18.3.1)(react@18.3.1)
|
||||
@@ -238,9 +238,6 @@ devDependencies:
|
||||
concurrently:
|
||||
specifier: ^8.2.2
|
||||
version: 8.2.2
|
||||
csstype:
|
||||
specifier: ^3.1.3
|
||||
version: 3.1.3
|
||||
dpdm:
|
||||
specifier: ^3.14.0
|
||||
version: 3.14.0
|
||||
@@ -2056,7 +2053,7 @@ packages:
|
||||
dependencies:
|
||||
'@chakra-ui/dom-utils': 2.1.0
|
||||
react: 18.3.1
|
||||
react-focus-lock: 2.13.0(@types/react@18.3.3)(react@18.3.1)
|
||||
react-focus-lock: 2.12.1(@types/react@18.3.3)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
dev: false
|
||||
@@ -3787,288 +3784,6 @@ packages:
|
||||
resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==}
|
||||
dev: false
|
||||
|
||||
/@radix-ui/primitive@1.0.1:
|
||||
resolution: {integrity: sha512-yQ8oGX2GVsEYMWGxcovu1uGWPCxV5BFfeeYxqPmuAzUyLT9qmaMXSAhXpb0WrspIeqYzdJpkh2vHModJPgRIaw==}
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-compose-refs@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-fDSBgd44FKHa1FRMU59qBMPFcl2PZE+2nmqunj+BWFyYYjnhIDWL2ItDs3rrbJDQOtzt5nIebLCQc4QRfz6LJw==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-context@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ebbrdFoYTcuZ0v4wG5tedGnp9tzcV8awzsxYph7gXUyvnNLuTIcCk1q17JEbnVhXAKG9oX3KtchwiMIAYp9NLg==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-dialog@1.0.5(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-GjWJX/AUpB703eEBanuBnIWdIXg6NvJFCXcNlSZk4xdszCdhrJgBoUd1cGk67vFO+WdA2pfI/plOpqz/5GUP6Q==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/primitive': 1.0.1
|
||||
'@radix-ui/react-compose-refs': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-context': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-dismissable-layer': 1.0.5(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-focus-guards': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-focus-scope': 1.0.4(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-id': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-portal': 1.0.4(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-presence': 1.0.1(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-slot': 1.0.2(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-use-controllable-state': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
'@types/react-dom': 18.3.0
|
||||
aria-hidden: 1.2.4
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
react-remove-scroll: 2.5.5(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-dismissable-layer@1.0.5(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-aJeDjQhywg9LBu2t/At58hCvr7pEm0o2Ke1x33B+MhjNmmZ17sy4KImo0KPLgsnc/zN7GPdce8Cnn0SWvwZO7g==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/primitive': 1.0.1
|
||||
'@radix-ui/react-compose-refs': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-use-escape-keydown': 1.0.3(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
'@types/react-dom': 18.3.0
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-focus-guards@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-Rect2dWbQ8waGzhMavsIbmSVCgYxkXLxxR3ZvCX79JOglzdEy4JXMb98lq4hPxUbLr77nP0UOGf4rcMU+s1pUA==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-focus-scope@1.0.4(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-sL04Mgvf+FmyvZeYfNu1EPAaaxD+aw7cYeIB9L9Fvq8+urhltTRaEo5ysKOpHuKPclsZcSUMKlN05x4u+CINpA==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-compose-refs': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
'@types/react-dom': 18.3.0
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-id@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-tI7sT/kqYp8p96yGWY1OAnLHrqDgzHefRBKQ2YAkBS5ja7QLcZ9Z/uY7bEjPUatf8RomoXM8/1sMj1IJaE5UzQ==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-portal@1.0.4(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-Qki+C/EuGUVCQTOTD5vzJzJuMUlewbzuKyUy+/iHM2uwGiru9gZeBJtHAPKAEkB5KWGi9mP/CHKcY0wt1aW45Q==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
'@types/react-dom': 18.3.0
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-presence@1.0.1(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-UXLW4UAbIY5ZjcvzjfRFo5gxva8QirC9hF7wRE4U5gz+TP0DbRk+//qyuAQ1McDxBt1xNMBTaciFGvEmJvAZCg==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-compose-refs': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@radix-ui/react-use-layout-effect': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
'@types/react-dom': 18.3.0
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-primitive@1.0.3(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-yi58uVyoAcK/Nq1inRY56ZSjKypBNKTa/1mcL8qdl6oJeEaDbOldlzrGn7P6Q3Id5d+SYNGc5AJgc4vGhjs5+g==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
'@types/react-dom': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
react-dom: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
'@types/react-dom':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-slot': 1.0.2(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
'@types/react-dom': 18.3.0
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-slot@1.0.2(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-compose-refs': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-use-callback-ref@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-D94LjX4Sp0xJFVaoQOd3OO9k7tpBYNOXdVhkltUbGv2Qb9OXdrg/CpsjlZv7ia14Sylv398LswWBVVu5nqKzAQ==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-use-controllable-state@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-Svl5GY5FQeN758fWKrjM6Qb7asvXeiZltlT4U2gVfl8Gx5UAv2sMR0LWo8yhsIZh2oQ0eFdZ59aoOOMV7b47VA==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-use-escape-keydown@1.0.3(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-vyL82j40hcFicA+M4Ex7hVkB9vHgSse1ZWomAqV2Je3RleKGO5iM8KMOEtfoSB0PnIelMd2lATjTGMYqN5ylTg==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@radix-ui/react-use-callback-ref': 1.0.1(@types/react@18.3.3)(react@18.3.1)
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@radix-ui/react-use-layout-effect@1.0.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-v/5RegiJWYdoCvMnITBkNNx6bCj20fiaJnWtRkU18yITptraXjffz5Qbn05uOiQnOvi+dbkznkoaMltz1GnszQ==}
|
||||
peerDependencies:
|
||||
'@types/react': '*'
|
||||
react: ^16.8 || ^17.0 || ^18.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@babel/runtime': 7.25.4
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/@reactflow/background@11.3.14(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-Gewd7blEVT5Lh6jqrvOgd4G6Qk17eGKQfsDXgyRSqM+CTwDqRldG2LsWN4sNeno6sbqVIC2fZ+rAUBFA9ZEUDA==}
|
||||
peerDependencies:
|
||||
@@ -5495,6 +5210,7 @@ packages:
|
||||
resolution: {integrity: sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==}
|
||||
dependencies:
|
||||
'@types/react': 18.3.3
|
||||
dev: true
|
||||
|
||||
/@types/react-transition-group@4.4.10:
|
||||
resolution: {integrity: sha512-hT/+s0VQs2ojCX823m60m5f0sL5idt9SO6Tj6Dg+rdphGPIeJbJ6CxvBYkgkGKrYeDjvIpKTR38UzmtHJOGW3Q==}
|
||||
@@ -6552,21 +6268,6 @@ packages:
|
||||
requiresBuild: true
|
||||
dev: true
|
||||
|
||||
/cmdk@1.0.0(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-gDzVf0a09TvoJ5jnuPvygTB77+XdOSwEmJ88L6XPFPlv7T3RxbP9jgenfylrAMD0+Le1aO0nVjQUzl2g+vjz5Q==}
|
||||
peerDependencies:
|
||||
react: ^18.0.0
|
||||
react-dom: ^18.0.0
|
||||
dependencies:
|
||||
'@radix-ui/react-dialog': 1.0.5(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
'@radix-ui/react-primitive': 1.0.3(@types/react-dom@18.3.0)(@types/react@18.3.3)(react-dom@18.3.1)(react@18.3.1)
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
- '@types/react-dom'
|
||||
dev: false
|
||||
|
||||
/color-convert@1.9.3:
|
||||
resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==}
|
||||
dependencies:
|
||||
@@ -10011,8 +9712,8 @@ packages:
|
||||
resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==}
|
||||
dev: false
|
||||
|
||||
/react-focus-lock@2.13.0(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-w7aIcTwZwNzUp2fYQDMICy+khFwVmKmOrLF8kNsPS+dz4Oi/oxoVJ2wCMVvX6rWGriM/+mYaTyp1MRmkcs2amw==}
|
||||
/react-focus-lock@2.12.1(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-lfp8Dve4yJagkHiFrC1bGtib3mF2ktqwPJw4/WGcgPW+pJ/AVQA5X2vI7xgp13FcxFEpYBBHpXai/N2DBNC0Jw==}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
@@ -10166,25 +9867,6 @@ packages:
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-remove-scroll@2.5.5(@types/react@18.3.3)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
'@types/react': ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0
|
||||
peerDependenciesMeta:
|
||||
'@types/react':
|
||||
optional: true
|
||||
dependencies:
|
||||
'@types/react': 18.3.3
|
||||
react: 18.3.1
|
||||
react-remove-scroll-bar: 2.3.6(@types/react@18.3.3)(react@18.3.1)
|
||||
react-style-singleton: 2.2.1(@types/react@18.3.3)(react@18.3.1)
|
||||
tslib: 2.7.0
|
||||
use-callback-ref: 1.3.2(@types/react@18.3.3)(react@18.3.1)
|
||||
use-sidecar: 1.1.2(@types/react@18.3.3)(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
/react-resizable-panels@2.0.23(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-8ZKTwTU11t/FYwiwhMdtZYYyFxic5U5ysRu2YwfkAgDbUJXFvnWSJqhnzkSlW+mnDoNAzDCrJhdOSXBPA76wug==}
|
||||
peerDependencies:
|
||||
|
||||
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
@@ -164,10 +164,10 @@
|
||||
"alpha": "Alpha",
|
||||
"selected": "Selected",
|
||||
"tab": "Tab",
|
||||
"view": "View",
|
||||
"viewDesc": "Review images in a large gallery view",
|
||||
"edit": "Edit",
|
||||
"editDesc": "Edit on the Canvas",
|
||||
"viewing": "Viewing",
|
||||
"viewingDesc": "Review images in a large gallery view",
|
||||
"editing": "Editing",
|
||||
"editingDesc": "Edit on the Control Layers canvas",
|
||||
"comparing": "Comparing",
|
||||
"comparingDesc": "Comparing two images",
|
||||
"enabled": "Enabled",
|
||||
@@ -328,13 +328,9 @@
|
||||
"completedIn": "Completed in",
|
||||
"batch": "Batch",
|
||||
"origin": "Origin",
|
||||
"destination": "Destination",
|
||||
"upscaling": "Upscaling",
|
||||
"canvas": "Canvas",
|
||||
"generation": "Generation",
|
||||
"workflows": "Workflows",
|
||||
"other": "Other",
|
||||
"gallery": "Gallery",
|
||||
"originCanvas": "Canvas",
|
||||
"originWorkflows": "Workflows",
|
||||
"originOther": "Other",
|
||||
"batchFieldValues": "Batch Field Values",
|
||||
"item": "Item",
|
||||
"session": "Session",
|
||||
@@ -706,8 +702,6 @@
|
||||
"availableModels": "Available Models",
|
||||
"baseModel": "Base Model",
|
||||
"cancel": "Cancel",
|
||||
"clipEmbed": "CLIP Embed",
|
||||
"clipVision": "CLIP Vision",
|
||||
"config": "Config",
|
||||
"convert": "Convert",
|
||||
"convertingModelBegin": "Converting Model. Please wait.",
|
||||
@@ -795,16 +789,13 @@
|
||||
"settings": "Settings",
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"spandrelImageToImage": "Image to Image (Spandrel)",
|
||||
"starterModels": "Starter Models",
|
||||
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
|
||||
"syncModels": "Sync Models",
|
||||
"textualInversions": "Textual Inversions",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
"typePhraseHere": "Type phrase here",
|
||||
"t5Encoder": "T5 Encoder",
|
||||
"upcastAttention": "Upcast Attention",
|
||||
"uploadImage": "Upload Image",
|
||||
"urlOrLocalPath": "URL or Local Path",
|
||||
@@ -1654,16 +1645,6 @@
|
||||
"storeNotInitialized": "Store is not initialized"
|
||||
},
|
||||
"controlLayers": {
|
||||
"bookmark": "Bookmark for Quick Switch",
|
||||
"removeBookmark": "Remove Bookmark",
|
||||
"saveCanvasToGallery": "Save Canvas To Gallery",
|
||||
"saveBboxToGallery": "Save Bbox To Gallery",
|
||||
"savedToGalleryOk": "Saved to Gallery",
|
||||
"savedToGalleryError": "Error saving to gallery",
|
||||
"mergeVisible": "Merge Visible",
|
||||
"mergeVisibleOk": "Merged visible layers",
|
||||
"mergeVisibleError": "Error merging visible layers",
|
||||
"clearHistory": "Clear History",
|
||||
"generateMode": "Generate",
|
||||
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
|
||||
"composeMode": "Compose",
|
||||
@@ -1671,10 +1652,10 @@
|
||||
"autoSave": "Auto-save to Gallery",
|
||||
"resetCanvas": "Reset Canvas",
|
||||
"resetAll": "Reset All",
|
||||
"deleteAll": "Delete All",
|
||||
"clearCaches": "Clear Caches",
|
||||
"recalculateRects": "Recalculate Rects",
|
||||
"clipToBbox": "Clip Strokes to Bbox",
|
||||
"compositeMaskedRegions": "Composite Masked Regions",
|
||||
"addLayer": "Add Layer",
|
||||
"duplicate": "Duplicate",
|
||||
"moveToFront": "Move to Front",
|
||||
@@ -1693,47 +1674,35 @@
|
||||
"deletePrompt": "Delete Prompt",
|
||||
"resetRegion": "Reset Region",
|
||||
"debugLayers": "Debug Layers",
|
||||
"showHUD": "Show HUD",
|
||||
"rectangle": "Rectangle",
|
||||
"maskFill": "Mask Fill",
|
||||
"maskPreviewColor": "Mask Preview Color",
|
||||
"addPositivePrompt": "Add $t(common.positivePrompt)",
|
||||
"addNegativePrompt": "Add $t(common.negativePrompt)",
|
||||
"addIPAdapter": "Add $t(common.ipAdapter)",
|
||||
"addRasterLayer": "Add $t(controlLayers.rasterLayer)",
|
||||
"addControlLayer": "Add $t(controlLayers.controlLayer)",
|
||||
"addInpaintMask": "Add $t(controlLayers.inpaintMask)",
|
||||
"addRegionalGuidance": "Add $t(controlLayers.regionalGuidance)",
|
||||
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
||||
"raster": "Raster",
|
||||
"rasterLayer": "Raster Layer",
|
||||
"controlLayer": "Control Layer",
|
||||
"inpaintMask": "Inpaint Mask",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"ipAdapter": "IP Adapter",
|
||||
"sendToGallery": "Send To Gallery",
|
||||
"sendToGalleryDesc": "Generations will be sent to the gallery.",
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
"sendToCanvasDesc": "Generations will be staged onto the canvas.",
|
||||
"rasterLayer_withCount_one": "$t(controlLayers.rasterLayer)",
|
||||
"controlLayer_withCount_one": "$t(controlLayers.controlLayer)",
|
||||
"inpaintMask_withCount_one": "$t(controlLayers.inpaintMask)",
|
||||
"regionalGuidance_withCount_one": "$t(controlLayers.regionalGuidance)",
|
||||
"ipAdapter_withCount_one": "$t(controlLayers.ipAdapter)",
|
||||
"rasterLayer_withCount_other": "Raster Layers",
|
||||
"controlLayer_withCount_other": "Control Layers",
|
||||
"inpaintMask_withCount_other": "Inpaint Masks",
|
||||
"regionalGuidance_withCount_other": "Regional Guidance",
|
||||
"ipAdapter_withCount_other": "IP Adapters",
|
||||
"rasterLayer_one": "Raster Layer",
|
||||
"controlLayer_one": "Control Layer",
|
||||
"inpaintMask_one": "Inpaint Mask",
|
||||
"regionalGuidance_one": "Regional Guidance",
|
||||
"ipAdapter_one": "IP Adapter",
|
||||
"rasterLayer_other": "Raster Layers",
|
||||
"controlLayer_other": "Control Layers",
|
||||
"inpaintMask_other": "Inpaint Masks",
|
||||
"regionalGuidance_other": "Regional Guidance",
|
||||
"ipAdapter_other": "IP Adapters",
|
||||
"opacity": "Opacity",
|
||||
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
|
||||
"controlAdapters_withCount_hidden": "Control Adapters ({{count}} hidden)",
|
||||
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
|
||||
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
|
||||
"globalIPAdapters_withCount_hidden": "Global IP Adapters ({{count}} hidden)",
|
||||
"ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
|
||||
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
|
||||
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
|
||||
"controlAdapters_withCount_visible": "Control Adapters ({{count}})",
|
||||
"controlLayers_withCount_visible": "Control Layers ({{count}})",
|
||||
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
|
||||
"globalIPAdapters_withCount_visible": "Global IP Adapters ({{count}})",
|
||||
"ipAdapters_withCount_visible": "IP Adapters ({{count}})",
|
||||
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
|
||||
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
|
||||
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
||||
@@ -1746,8 +1715,8 @@
|
||||
"clearProcessor": "Clear Processor",
|
||||
"resetProcessor": "Reset Processor to Defaults",
|
||||
"noLayersAdded": "No Layers Added",
|
||||
"layer_one": "Layer",
|
||||
"layer_other": "Layers",
|
||||
"layers_one": "Layer",
|
||||
"layers_other": "Layers",
|
||||
"objects_zero": "empty",
|
||||
"objects_one": "{{count}} object",
|
||||
"objects_other": "{{count}} objects",
|
||||
@@ -1760,14 +1729,7 @@
|
||||
"showingType": "Showing {{type}}",
|
||||
"dynamicGrid": "Dynamic Grid",
|
||||
"logDebugInfo": "Log Debug Info",
|
||||
"locked": "Locked",
|
||||
"unlocked": "Unlocked",
|
||||
"deleteSelected": "Delete Selected",
|
||||
"deleteAll": "Delete All",
|
||||
"flipHorizontal": "Flip Horizontal",
|
||||
"flipVertical": "Flip Vertical",
|
||||
"fill": {
|
||||
"fillColor": "Fill Color",
|
||||
"fillStyle": "Fill Style",
|
||||
"solid": "Solid",
|
||||
"grid": "Grid",
|
||||
@@ -1783,6 +1745,7 @@
|
||||
"bbox": "Bbox",
|
||||
"move": "Move",
|
||||
"view": "View",
|
||||
"transform": "Transform",
|
||||
"colorPicker": "Color Picker"
|
||||
},
|
||||
"filter": {
|
||||
@@ -1792,13 +1755,6 @@
|
||||
"preview": "Preview",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"transform": {
|
||||
"transform": "Transform",
|
||||
"fitToBbox": "Fit to Bbox",
|
||||
"reset": "Reset",
|
||||
"apply": "Apply",
|
||||
"cancel": "Cancel"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
|
||||
@@ -16,11 +16,8 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
|
||||
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
|
||||
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { selectLanguage } from 'features/system/store/systemSelectors';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import { AppContent } from 'features/ui/components/AppContent';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
@@ -44,18 +41,11 @@ interface Props {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: TabName;
|
||||
destination?: TabName | undefined;
|
||||
}
|
||||
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(selectLanguage);
|
||||
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
const logger = useLogger('system');
|
||||
const dispatch = useAppDispatch();
|
||||
const clearStorage = useClearStorage();
|
||||
@@ -93,12 +83,6 @@ const App = ({
|
||||
}
|
||||
}, [selectedWorkflowId, getAndLoadWorkflow]);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedStylePresetId) {
|
||||
dispatch(activeStylePresetIdChanged(selectedStylePresetId));
|
||||
}
|
||||
}, [dispatch, selectedStylePresetId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (destination) {
|
||||
dispatch(setActiveTab(destination));
|
||||
@@ -137,8 +121,6 @@ const App = ({
|
||||
<StylePresetModal />
|
||||
<ClearQueueConfirmationsAlertDialog />
|
||||
<PreselectedImage selectedImage={selectedImage} />
|
||||
<SettingsModal />
|
||||
<RefreshAfterResetModal />
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import newGithubIssueUrl from 'new-github-issue-url';
|
||||
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
|
||||
@@ -15,11 +13,9 @@ type Props = {
|
||||
resetErrorBoundary: () => void;
|
||||
};
|
||||
|
||||
const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
|
||||
|
||||
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const isLocal = useAppSelector(selectIsLocal);
|
||||
const isLocal = useAppSelector((s) => s.config.isLocal);
|
||||
|
||||
const handleCopy = useCallback(() => {
|
||||
const text = JSON.stringify(serializeError(error), null, 2);
|
||||
|
||||
@@ -45,7 +45,6 @@ interface Props extends PropsWithChildren {
|
||||
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
|
||||
};
|
||||
selectedWorkflowId?: string;
|
||||
selectedStylePresetId?: string;
|
||||
destination?: TabName;
|
||||
customStarUi?: CustomStarUi;
|
||||
socketOptions?: Partial<ManagerOptions & SocketOptions>;
|
||||
@@ -67,7 +66,6 @@ const InvokeAIUI = ({
|
||||
queueId,
|
||||
selectedImage,
|
||||
selectedWorkflowId,
|
||||
selectedStylePresetId,
|
||||
destination,
|
||||
customStarUi,
|
||||
socketOptions,
|
||||
@@ -229,7 +227,6 @@ const InvokeAIUI = ({
|
||||
config={config}
|
||||
selectedImage={selectedImage}
|
||||
selectedWorkflowId={selectedWorkflowId}
|
||||
selectedStylePresetId={selectedStylePresetId}
|
||||
destination={destination}
|
||||
/>
|
||||
</AppDndContext>
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import { createLogWriter } from '@roarr/browser-log-writer';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectSystemLogIsEnabled,
|
||||
selectSystemLogLevel,
|
||||
selectSystemLogNamespaces,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { ROARR, Roarr } from 'roarr';
|
||||
|
||||
@@ -12,9 +7,9 @@ import type { LogNamespace } from './logger';
|
||||
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
|
||||
|
||||
export const useLogger = (namespace: LogNamespace) => {
|
||||
const logLevel = useAppSelector(selectSystemLogLevel);
|
||||
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
|
||||
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
|
||||
const logLevel = useAppSelector((s) => s.system.logLevel);
|
||||
const logNamespaces = useAppSelector((s) => s.system.logNamespaces);
|
||||
const logIsEnabled = useAppSelector((s) => s.system.logIsEnabled);
|
||||
|
||||
// The provided Roarr browser log writer uses localStorage to config logging to console
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
export const STORAGE_PREFIX = '@@invokeai-';
|
||||
export const EMPTY_ARRAY = [];
|
||||
export const EMPTY_OBJECT = {};
|
||||
|
||||
@@ -4,8 +4,7 @@ import {
|
||||
sessionStagingAreaImageAccepted,
|
||||
sessionStagingAreaReset,
|
||||
} from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -59,7 +58,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
const stagingAreaImage = state.canvasSession.stagedImages[index];
|
||||
|
||||
assert(stagingAreaImage, 'No staged image found to accept');
|
||||
const { x, y } = selectCanvasSlice(state).bbox.rect;
|
||||
const { x, y } = state.canvasV2.present.bbox.rect;
|
||||
|
||||
const { imageDTO, offsetX, offsetY } = stagingAreaImage;
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
@@ -68,7 +67,7 @@ export const addStagingListeners = (startAppListening: AppStartListening) => {
|
||||
objects: [imageObject],
|
||||
};
|
||||
|
||||
api.dispatch(rasterLayerAdded({ overrides, isSelected: false }));
|
||||
api.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
|
||||
api.dispatch(sessionStagingAreaReset());
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
|
||||
@@ -15,12 +13,10 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
|
||||
|
||||
let wasNodeEditorReset = false;
|
||||
|
||||
const state = getState();
|
||||
const nodes = selectNodesSlice(state);
|
||||
const canvas = selectCanvasSlice(state);
|
||||
const { nodes, canvasV2 } = getState();
|
||||
|
||||
deleted_images.forEach((image_name) => {
|
||||
const imageUsage = getImageUsage(nodes, canvas, image_name);
|
||||
const imageUsage = getImageUsage(nodes.present, canvasV2.present, image_name);
|
||||
|
||||
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
|
||||
dispatch(nodeEditorReset());
|
||||
|
||||
@@ -31,7 +31,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
let didStartStaging = false;
|
||||
|
||||
if (!state.canvasSession.isStaging && state.canvasSettings.sendToCanvas) {
|
||||
if (!state.canvasSession.isStaging && state.canvasSession.mode === 'compose') {
|
||||
dispatch(sessionStartedStaging());
|
||||
didStartStaging = true;
|
||||
}
|
||||
@@ -70,11 +70,7 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
const { g, noise, posCond } = buildGraphResult.value;
|
||||
|
||||
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch(state, g, prepend, noise, posCond, 'generation', destination)
|
||||
);
|
||||
const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond));
|
||||
|
||||
if (isErr(prepareBatchResult)) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
@@ -12,12 +11,12 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const nodes = selectNodesSlice(state);
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
const workflow = state.workflow;
|
||||
const graph = buildNodesGraph(nodes);
|
||||
const graph = buildNodesGraph(state.nodes.present);
|
||||
const builtWorkflow = buildWorkflowWithValidation({
|
||||
nodes: nodes.nodes,
|
||||
edges: nodes.edges,
|
||||
nodes,
|
||||
edges,
|
||||
workflow,
|
||||
});
|
||||
|
||||
@@ -32,7 +31,6 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
workflow: builtWorkflow,
|
||||
runs: state.params.iterations,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
},
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
@@ -16,7 +16,7 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
|
||||
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond, 'upscaling', 'gallery');
|
||||
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||
@@ -41,7 +40,7 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
|
||||
};
|
||||
|
||||
// const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
// state.canvas.present.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
|
||||
// state.canvasV2.present.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
|
||||
// if (
|
||||
// imageObject?.image.image_name === imageDTO.image_name ||
|
||||
// processedImageObject?.image.image_name === imageDTO.image_name
|
||||
@@ -53,7 +52,7 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
|
||||
// };
|
||||
|
||||
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
|
||||
state.canvasV2.present.ipAdapters.entities.forEach((entity) => {
|
||||
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
|
||||
dispatch(ipaImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
|
||||
}
|
||||
@@ -61,7 +60,7 @@ const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO
|
||||
};
|
||||
|
||||
const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
|
||||
state.canvasV2.present.rasterLayers.entities.forEach(({ id, objects }) => {
|
||||
let shouldDelete = false;
|
||||
for (const obj of objects) {
|
||||
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
|
||||
|
||||
@@ -6,8 +6,7 @@ import {
|
||||
ipaImageChanged,
|
||||
rasterLayerAdded,
|
||||
rgIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
@@ -86,7 +85,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const { x, y } = getState().canvasV2.present.bbox.rect;
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
@@ -104,7 +103,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
|
||||
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
|
||||
const { x, y } = getState().canvasV2.present.bbox.rect;
|
||||
const overrides: Partial<CanvasControlLayerState> = {
|
||||
objects: [imageObject],
|
||||
position: { x, y },
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
|
||||
@@ -46,7 +46,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
|
||||
// handle incompatible controlnets
|
||||
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
|
||||
// state.canvasV2.present.controlAdapters.entities.forEach((ca) => {
|
||||
// if (ca.model?.base !== newBaseModel) {
|
||||
// modelsCleared += 1;
|
||||
// if (ca.isEnabled) {
|
||||
|
||||
@@ -8,10 +8,9 @@ import {
|
||||
controlLayerModelChanged,
|
||||
ipaModelChanged,
|
||||
rgIPAdapterModelChanged,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
@@ -82,12 +81,21 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const result = zParameterModel.safeParse(defaultModelInList);
|
||||
if (result.success) {
|
||||
dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
|
||||
const { bbox } = selectCanvasSlice(state);
|
||||
|
||||
const optimalDimension = getOptimalDimension(defaultModelInList);
|
||||
if (getIsSizeOptimal(bbox.rect.width, bbox.rect.height, optimalDimension)) {
|
||||
if (
|
||||
getIsSizeOptimal(
|
||||
state.canvasV2.present.bbox.rect.width,
|
||||
state.canvasV2.present.bbox.rect.height,
|
||||
optimalDimension
|
||||
)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
const { width, height } = calculateNewSize(bbox.aspectRatio.value, optimalDimension * optimalDimension);
|
||||
const { width, height } = calculateNewSize(
|
||||
state.canvasV2.present.bbox.aspectRatio.value,
|
||||
optimalDimension * optimalDimension
|
||||
);
|
||||
|
||||
dispatch(bboxWidthChanged({ width }));
|
||||
dispatch(bboxHeightChanged({ height }));
|
||||
@@ -170,7 +178,7 @@ const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
|
||||
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
|
||||
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
|
||||
state.canvasV2.present.controlLayers.entities.forEach((entity) => {
|
||||
const isModelAvailable = caModels.some((m) => m.key === entity.controlAdapter.model?.key);
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
@@ -181,7 +189,7 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
|
||||
|
||||
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const ipaModels = models.filter(isIPAdapterModelConfig);
|
||||
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
|
||||
state.canvasV2.present.ipAdapters.entities.forEach((entity) => {
|
||||
const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
|
||||
if (isModelAvailable) {
|
||||
return;
|
||||
@@ -189,7 +197,7 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
dispatch(ipaModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
|
||||
});
|
||||
|
||||
selectCanvasSlice(state).regions.entities.forEach((entity) => {
|
||||
state.canvasV2.present.regions.entities.forEach((entity) => {
|
||||
entity.ipAdapters.forEach(({ id: ipAdapterId, model }) => {
|
||||
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
|
||||
if (isModelAvailable) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import {
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
|
||||
@@ -2,7 +2,6 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodes } from 'features/nodes/store/selectors';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
|
||||
@@ -15,7 +14,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
startAppListening({
|
||||
actionCreator: updateAllNodesRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const nodes = selectNodes(getState());
|
||||
const { nodes } = getState().nodes.present;
|
||||
const templates = $templates.get();
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
@@ -8,9 +8,10 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
|
||||
import { canvasSessionPersistConfig, canvasSessionSlice } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { canvasPersistConfig, canvasSlice, canvasUndoableConfig } from 'features/controlLayers/store/canvasSlice';
|
||||
import { canvasV2PersistConfig, canvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { lorasPersistConfig, lorasSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { toolPersistConfig, toolSlice } from 'features/controlLayers/store/toolSlice';
|
||||
import { deleteImageModalSlice } from 'features/deleteImageModal/store/slice';
|
||||
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
@@ -57,11 +58,12 @@ const allReducers = {
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
|
||||
[canvasV2Slice.name]: undoable(canvasV2Slice.reducer),
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[upscaleSlice.name]: upscaleSlice.reducer,
|
||||
[stylePresetSlice.name]: stylePresetSlice.reducer,
|
||||
[paramsSlice.name]: paramsSlice.reducer,
|
||||
[toolSlice.name]: toolSlice.reducer,
|
||||
[canvasSettingsSlice.name]: canvasSettingsSlice.reducer,
|
||||
[canvasSessionSlice.name]: canvasSessionSlice.reducer,
|
||||
[lorasSlice.name]: lorasSlice.reducer,
|
||||
@@ -102,11 +104,12 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[canvasPersistConfig.name]: canvasPersistConfig,
|
||||
[canvasV2PersistConfig.name]: canvasV2PersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
[upscalePersistConfig.name]: upscalePersistConfig,
|
||||
[stylePresetPersistConfig.name]: stylePresetPersistConfig,
|
||||
[paramsPersistConfig.name]: paramsPersistConfig,
|
||||
[toolPersistConfig.name]: toolPersistConfig,
|
||||
[canvasSettingsPersistConfig.name]: canvasSettingsPersistConfig,
|
||||
[canvasSessionPersistConfig.name]: canvasSessionPersistConfig,
|
||||
[lorasPersistConfig.name]: lorasPersistConfig,
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, IconButton, Tooltip, useToken } from '@invoke-ai/ui-library';
|
||||
import type { ReactElement, ReactNode } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type IconSwitchProps = {
|
||||
isChecked: boolean;
|
||||
onChange: (checked: boolean) => void;
|
||||
iconChecked: ReactElement;
|
||||
tooltipChecked?: ReactNode;
|
||||
iconUnchecked: ReactElement;
|
||||
tooltipUnchecked?: ReactNode;
|
||||
ariaLabel: string;
|
||||
};
|
||||
|
||||
const getSx = (padding: string | number): SystemStyleObject => ({
|
||||
transition: 'left 0.1s ease-in-out, transform 0.1s ease-in-out',
|
||||
'&[data-checked="true"]': {
|
||||
left: `calc(100% - ${padding})`,
|
||||
transform: 'translateX(-100%)',
|
||||
},
|
||||
'&[data-checked="false"]': {
|
||||
left: padding,
|
||||
transform: 'translateX(0)',
|
||||
},
|
||||
});
|
||||
|
||||
export const IconSwitch = memo(
|
||||
({
|
||||
isChecked,
|
||||
onChange,
|
||||
iconChecked,
|
||||
tooltipChecked,
|
||||
iconUnchecked,
|
||||
tooltipUnchecked,
|
||||
ariaLabel,
|
||||
}: IconSwitchProps) => {
|
||||
const onUncheck = useCallback(() => {
|
||||
onChange(false);
|
||||
}, [onChange]);
|
||||
const onCheck = useCallback(() => {
|
||||
onChange(true);
|
||||
}, [onChange]);
|
||||
|
||||
const gap = useToken('space', 1.5);
|
||||
const sx = useMemo(() => getSx(gap), [gap]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
bg="base.800"
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
h="full"
|
||||
p={gap}
|
||||
gap={gap}
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
borderRadius="base"
|
||||
bg="invokeBlue.400"
|
||||
w={12}
|
||||
top={gap}
|
||||
bottom={gap}
|
||||
data-checked={isChecked}
|
||||
sx={sx}
|
||||
/>
|
||||
<Tooltip hasArrow label={tooltipUnchecked}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
fontSize={16}
|
||||
icon={iconUnchecked}
|
||||
onClick={onUncheck}
|
||||
variant={!isChecked ? 'solid' : 'ghost'}
|
||||
colorScheme={!isChecked ? 'invokeBlue' : 'base'}
|
||||
aria-label={ariaLabel}
|
||||
data-checked={!isChecked}
|
||||
w={12}
|
||||
alignSelf="stretch"
|
||||
h="auto"
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip hasArrow label={tooltipChecked}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
fontSize={16}
|
||||
icon={iconChecked}
|
||||
onClick={onCheck}
|
||||
variant={isChecked ? 'solid' : 'ghost'}
|
||||
colorScheme={isChecked ? 'invokeBlue' : 'base'}
|
||||
aria-label={ariaLabel}
|
||||
data-checked={isChecked}
|
||||
w={12}
|
||||
alignSelf="stretch"
|
||||
h="auto"
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IconSwitch.displayName = 'IconSwitch';
|
||||
@@ -13,9 +13,8 @@ import {
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSystemSlice, setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
|
||||
import { setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type { ReactElement } from 'react';
|
||||
@@ -32,13 +31,8 @@ type Props = {
|
||||
children: ReactElement;
|
||||
};
|
||||
|
||||
const selectShouldEnableInformationalPopovers = createSelector(
|
||||
selectSystemSlice,
|
||||
(system) => system.shouldEnableInformationalPopovers
|
||||
);
|
||||
|
||||
export const InformationalPopover = memo(({ feature, children, inPortal = true, ...rest }: Props) => {
|
||||
const shouldEnableInformationalPopovers = useAppSelector(selectShouldEnableInformationalPopovers);
|
||||
const shouldEnableInformationalPopovers = useAppSelector((s) => s.system.shouldEnableInformationalPopovers);
|
||||
|
||||
const data = useMemo(() => POPOVER_DATA[feature], [feature]);
|
||||
|
||||
|
||||
@@ -1,74 +1,52 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type { WritableAtom } from 'nanostores';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
|
||||
type UseBoolean = {
|
||||
isTrue: boolean;
|
||||
setTrue: () => void;
|
||||
setFalse: () => void;
|
||||
set: (value: boolean) => void;
|
||||
toggle: () => void;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a hook to manage a boolean state. The boolean is stored in a nanostores atom.
|
||||
* Returns a tuple containing the hook and the atom. Use this for global boolean state.
|
||||
* @param initialValue Initial value of the boolean
|
||||
*/
|
||||
export const buildUseBoolean = (initialValue: boolean): [() => UseBoolean, WritableAtom<boolean>] => {
|
||||
const $boolean = atom(initialValue);
|
||||
|
||||
const setTrue = () => {
|
||||
$boolean.set(true);
|
||||
};
|
||||
const setFalse = () => {
|
||||
$boolean.set(false);
|
||||
};
|
||||
const set = (value: boolean) => {
|
||||
$boolean.set(value);
|
||||
};
|
||||
const toggle = () => {
|
||||
$boolean.set(!$boolean.get());
|
||||
};
|
||||
|
||||
const useBoolean = () => {
|
||||
const isTrue = useStore($boolean);
|
||||
|
||||
return {
|
||||
isTrue,
|
||||
setTrue,
|
||||
setFalse,
|
||||
set,
|
||||
toggle,
|
||||
};
|
||||
};
|
||||
|
||||
return [useBoolean, $boolean] as const;
|
||||
};
|
||||
|
||||
/**
|
||||
* Hook to manage a boolean state. Use this for a local boolean state.
|
||||
* @param initialValue Initial value of the boolean
|
||||
*/
|
||||
export const useBoolean = (initialValue: boolean) => {
|
||||
const [isTrue, set] = useState(initialValue);
|
||||
const setTrue = useCallback(() => set(true), []);
|
||||
const setFalse = useCallback(() => set(false), []);
|
||||
const toggle = useCallback(() => set((v) => !v), []);
|
||||
|
||||
const setTrue = useCallback(() => {
|
||||
set(true);
|
||||
}, [set]);
|
||||
const setFalse = useCallback(() => {
|
||||
set(false);
|
||||
}, [set]);
|
||||
const toggle = useCallback(() => {
|
||||
set((val) => !val);
|
||||
}, [set]);
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
isTrue,
|
||||
set,
|
||||
setTrue,
|
||||
setFalse,
|
||||
toggle,
|
||||
}),
|
||||
[isTrue, set, setTrue, setFalse, toggle]
|
||||
);
|
||||
|
||||
return {
|
||||
isTrue,
|
||||
setTrue,
|
||||
setFalse,
|
||||
set,
|
||||
toggle,
|
||||
return api;
|
||||
};
|
||||
|
||||
export const buildUseBoolean = ($boolean: WritableAtom<boolean>) => {
|
||||
return () => {
|
||||
const setTrue = useCallback(() => {
|
||||
$boolean.set(true);
|
||||
}, []);
|
||||
const setFalse = useCallback(() => {
|
||||
$boolean.set(false);
|
||||
}, []);
|
||||
const set = useCallback((value: boolean) => {
|
||||
$boolean.set(value);
|
||||
}, []);
|
||||
const toggle = useCallback(() => {
|
||||
$boolean.set(!$boolean.get());
|
||||
}, []);
|
||||
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
setTrue,
|
||||
setFalse,
|
||||
set,
|
||||
toggle,
|
||||
$boolean,
|
||||
}),
|
||||
[set, setFalse, setTrue, toggle]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
@@ -27,7 +26,7 @@ const selectPostUploadAction = createMemoizedSelector(selectActiveTab, (activeTa
|
||||
|
||||
export const useFullscreenDropzone = () => {
|
||||
const { t } = useTranslation();
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
|
||||
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
|
||||
const postUploadAction = useAppSelector(selectPostUploadAction);
|
||||
const [uploadImage] = useUploadImageMutation();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addScope, removeScope, setScopes } from 'common/hooks/interactionScopes';
|
||||
import { useClearQueue } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { useClearQueue } from 'features/queue/hooks/useClearQueue';
|
||||
import { useQueueBack } from 'features/queue/hooks/useQueueBack';
|
||||
import { useQueueFront } from 'features/queue/hooks/useQueueFront';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { groupBy, reduce } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@@ -30,13 +28,11 @@ const groupByBaseFunc = <T extends AnyModelConfig>(model: T) => model.base.toUpp
|
||||
const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T) =>
|
||||
`${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`;
|
||||
|
||||
const selectBaseWithSDXLFallback = createSelector(selectParamsSlice, (params) => params.model?.base ?? 'sdxl');
|
||||
|
||||
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
arg: UseGroupedModelComboboxArg<T>
|
||||
): UseGroupedModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const base = useAppSelector(selectBaseWithSDXLFallback);
|
||||
const base_model = useAppSelector((s) => s.params.model?.base ?? 'sdxl');
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
|
||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||
if (!modelConfigs) {
|
||||
@@ -58,9 +54,9 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
},
|
||||
[] as GroupBase<ComboboxOption>[]
|
||||
);
|
||||
_options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base) ? -1 : 1));
|
||||
_options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1));
|
||||
return _options;
|
||||
}, [modelConfigs, groupByType, getIsDisabled, base]);
|
||||
}, [modelConfigs, groupByType, getIsDisabled, base_model]);
|
||||
|
||||
const value = useMemo(
|
||||
() =>
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { useCallback } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||
@@ -30,7 +29,7 @@ type UseImageUploadButtonArgs = {
|
||||
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
|
||||
*/
|
||||
export const useImageUploadButton = ({ postUploadAction, isDisabled }: UseImageUploadButtonArgs) => {
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
|
||||
const [uploadImage] = useUploadImageMutation();
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
|
||||
@@ -3,11 +3,10 @@ import { $isConnected } from 'app/hooks/useSocketIO';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { selectCanvasV2Slice } from 'features/controlLayers/store/selectors';
|
||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
@@ -35,14 +34,14 @@ const createSelector = (templates: Templates, isConnected: boolean) =>
|
||||
selectNodesSlice,
|
||||
selectWorkflowSettingsSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
selectCanvasSlice,
|
||||
selectCanvasV2Slice,
|
||||
selectParamsSlice,
|
||||
selectUpscalelice,
|
||||
selectConfigSlice,
|
||||
selectActiveTab,
|
||||
],
|
||||
(system, nodes, workflowSettings, dynamicPrompts, canvas, params, upscale, config, activeTabName) => {
|
||||
const { bbox } = canvas;
|
||||
(system, nodes, workflowSettings, dynamicPrompts, canvasV2, params, upscale, config, activeTabName) => {
|
||||
const { bbox } = canvasV2;
|
||||
const { model, positivePrompt } = params;
|
||||
|
||||
const reasons: { prefix?: string; content: string }[] = [];
|
||||
@@ -125,10 +124,10 @@ const createSelector = (templates: Templates, isConnected: boolean) =>
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
|
||||
}
|
||||
|
||||
canvas.controlLayers.entities
|
||||
canvasV2.controlLayers.entities
|
||||
.filter((controlLayer) => controlLayer.isEnabled)
|
||||
.forEach((controlLayer, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layer_one');
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
@@ -155,10 +154,10 @@ const createSelector = (templates: Templates, isConnected: boolean) =>
|
||||
}
|
||||
});
|
||||
|
||||
canvas.ipAdapters.entities
|
||||
canvasV2.ipAdapters.entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.forEach((entity, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layer_one');
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
@@ -183,10 +182,10 @@ const createSelector = (templates: Templates, isConnected: boolean) =>
|
||||
}
|
||||
});
|
||||
|
||||
canvas.regions.entities
|
||||
canvasV2.regions.entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.forEach((entity, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layer_one');
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
@@ -220,10 +219,10 @@ const createSelector = (templates: Templates, isConnected: boolean) =>
|
||||
}
|
||||
});
|
||||
|
||||
canvas.rasterLayers.entities
|
||||
canvasV2.rasterLayers.entities
|
||||
.filter((entity) => entity.isEnabled)
|
||||
.forEach((entity, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layer_one');
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
export const stopPropagation = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
export const preventDefault = (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
};
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
@@ -19,17 +18,12 @@ const selectImagesToChange = createMemoizedSelector(
|
||||
(changeBoardModal) => changeBoardModal.imagesToChange
|
||||
);
|
||||
|
||||
const selectIsModalOpen = createSelector(
|
||||
selectChangeBoardModalSlice,
|
||||
(changeBoardModal) => changeBoardModal.isModalOpen
|
||||
);
|
||||
|
||||
const ChangeBoardModal = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
|
||||
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
|
||||
const isModalOpen = useAppSelector(selectIsModalOpen);
|
||||
const isModalOpen = useAppSelector((s) => s.changeBoardModal.isModalOpen);
|
||||
const imagesToChange = useAppSelector(selectImagesToChange);
|
||||
const [addImagesToBoard] = useAddImagesToBoardMutation();
|
||||
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
|
||||
@@ -83,7 +77,6 @@ const ChangeBoardModal = () => {
|
||||
acceptCallback={handleChangeBoard}
|
||||
acceptButtonText={t('boards.move')}
|
||||
cancelButtonText={t('boards.cancel')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Text>
|
||||
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
@@ -31,22 +31,22 @@ export const CanvasAddEntityButtons = memo(() => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" h="full" alignItems="center">
|
||||
<ButtonGroup position="relative" orientation="vertical" isAttached={false} top="20%">
|
||||
<Flex flexDir="column" w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<ButtonGroup orientation="vertical" isAttached={false}>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addInpaintMask}>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
{t('controlLayers.inpaintMask', { count: 1 })}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRegionalGuidance}>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
{t('controlLayers.regionalGuidance', { count: 1 })}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addRasterLayer}>
|
||||
{t('controlLayers.rasterLayer')}
|
||||
{t('controlLayers.rasterLayer', { count: 1 })}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addControlLayer}>
|
||||
{t('controlLayers.controlLayer')}
|
||||
{t('controlLayers.controlLayer', { count: 1 })}
|
||||
</Button>
|
||||
<Button variant="ghost" justifyContent="flex-start" leftIcon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||
{t('controlLayers.globalIPAdapter')}
|
||||
{t('controlLayers.ipAdapter', { count: 1 })}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { CanvasEntityOpacity } from 'features/controlLayers/components/common/CanvasEntityOpacity';
|
||||
import { ControlLayerEntityList } from 'features/controlLayers/components/ControlLayer/ControlLayerEntityList';
|
||||
import { InpaintMaskList } from 'features/controlLayers/components/InpaintMask/InpaintMaskList';
|
||||
import { IPAdapterList } from 'features/controlLayers/components/IPAdapter/IPAdapterList';
|
||||
@@ -10,7 +11,8 @@ import { memo } from 'react';
|
||||
export const CanvasEntityList = memo(() => {
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={2} data-testid="control-layers-layer-list" w="full" h="full">
|
||||
<Flex flexDir="column" gap={4} pt={2} data-testid="control-layers-layer-list" w="full" h="full">
|
||||
<CanvasEntityOpacity />
|
||||
<InpaintMaskList />
|
||||
<RegionalGuidanceEntityList />
|
||||
<IPAdapterList />
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityListMenuItems } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarAddLayerMenuItems';
|
||||
import { CanvasEntityListMenuItems } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityListMenuItems';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import { PiDotsThreeOutlineFill } from 'react-icons/pi';
|
||||
|
||||
export const EntityListActionBarAddLayerButton = memo(() => {
|
||||
export const CanvasEntityListMenuButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton
|
||||
as={IconButton}
|
||||
size="sm"
|
||||
tooltip={t('controlLayers.addLayer')}
|
||||
aria-label={t('controlLayers.addLayer')}
|
||||
icon={<PiPlusBold />}
|
||||
variant="ghost"
|
||||
aria-label={t('accessibility.menu')}
|
||||
icon={<PiDotsThreeOutlineFill />}
|
||||
variant="link"
|
||||
data-testid="control-layers-add-layer-menu-button"
|
||||
alignSelf="stretch"
|
||||
/>
|
||||
<MenuList>
|
||||
<CanvasEntityListMenuItems />
|
||||
@@ -25,4 +24,4 @@ export const EntityListActionBarAddLayerButton = memo(() => {
|
||||
);
|
||||
});
|
||||
|
||||
EntityListActionBarAddLayerButton.displayName = 'EntityListActionBarAddLayerButton';
|
||||
CanvasEntityListMenuButton.displayName = 'CanvasEntityListMenuButton';
|
||||
@@ -1,21 +1,25 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useDefaultIPAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
|
||||
import { MenuDivider, MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
allEntitiesDeleted,
|
||||
controlLayerAdded,
|
||||
inpaintMaskAdded,
|
||||
ipaAdded,
|
||||
rasterLayerAdded,
|
||||
rgAdded,
|
||||
} from 'features/controlLayers/store/canvasSlice';
|
||||
} from 'features/controlLayers/store/canvasV2Slice';
|
||||
import { selectEntityCount } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import { PiPlusBold, PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasEntityListMenuItems = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultIPAdapter = useDefaultIPAdapter();
|
||||
const hasEntities = useAppSelector((s) => {
|
||||
const count = selectEntityCount(s);
|
||||
return count > 0;
|
||||
});
|
||||
const addInpaintMask = useCallback(() => {
|
||||
dispatch(inpaintMaskAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
@@ -29,26 +33,32 @@ export const CanvasEntityListMenuItems = memo(() => {
|
||||
dispatch(controlLayerAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const addIPAdapter = useCallback(() => {
|
||||
const overrides = { ipAdapter: defaultIPAdapter };
|
||||
dispatch(ipaAdded({ isSelected: true, overrides }));
|
||||
}, [defaultIPAdapter, dispatch]);
|
||||
dispatch(ipaAdded({ isSelected: true }));
|
||||
}, [dispatch]);
|
||||
const deleteAll = useCallback(() => {
|
||||
dispatch(allEntitiesDeleted());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
{t('controlLayers.inpaintMask', { count: 1 })}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance}>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
{t('controlLayers.regionalGuidance', { count: 1 })}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
|
||||
{t('controlLayers.rasterLayer')}
|
||||
{t('controlLayers.rasterLayer', { count: 1 })}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
|
||||
{t('controlLayers.controlLayer')}
|
||||
{t('controlLayers.controlLayer', { count: 1 })}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlusBold />} onClick={addIPAdapter}>
|
||||
{t('controlLayers.globalIPAdapter')}
|
||||
{t('controlLayers.ipAdapter', { count: 1 })}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem onClick={deleteAll} icon={<PiTrashSimpleBold />} color="error.300" isDisabled={!hasEntities}>
|
||||
{t('controlLayers.deleteAll', { count: 1 })}
|
||||
</MenuItem>
|
||||
</>
|
||||
);
|
||||
@@ -1,20 +0,0 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { EntityListActionBarAddLayerButton } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarAddLayerMenuButton';
|
||||
import { EntityListActionBarDeleteButton } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarDeleteButton';
|
||||
import { EntityListActionBarSelectedEntityFill } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarSelectedEntityFill';
|
||||
import { SelectedEntityOpacity } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBarSelectedEntityOpacity';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const EntityListActionBar = memo(() => {
|
||||
return (
|
||||
<Flex w="full" py={1} px={1} gap={2} alignItems="center">
|
||||
<SelectedEntityOpacity />
|
||||
<Spacer />
|
||||
<EntityListActionBarSelectedEntityFill />
|
||||
<EntityListActionBarAddLayerButton />
|
||||
<EntityListActionBarDeleteButton />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
EntityListActionBar.displayName = 'EntityListActionBar';
|
||||
@@ -1,39 +0,0 @@
|
||||
import { IconButton, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { allEntitiesDeleted, entityDeleted } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectEntityCount, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleFill } from 'react-icons/pi';
|
||||
|
||||
export const EntityListActionBarDeleteButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const entityCount = useAppSelector(selectEntityCount);
|
||||
const shift = useShiftModifier();
|
||||
const onClick = useCallback(() => {
|
||||
if (shift) {
|
||||
dispatch(allEntitiesDeleted());
|
||||
return;
|
||||
}
|
||||
if (!selectedEntityIdentifier) {
|
||||
return;
|
||||
}
|
||||
dispatch(entityDeleted({ entityIdentifier: selectedEntityIdentifier }));
|
||||
}, [dispatch, selectedEntityIdentifier, shift]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={onClick}
|
||||
isDisabled={shift ? entityCount === 0 : !selectedEntityIdentifier}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
aria-label={shift ? t('controlLayers.deleteAll') : t('controlLayers.deleteSelected')}
|
||||
tooltip={shift ? t('controlLayers.deleteAll') : t('controlLayers.deleteSelected')}
|
||||
icon={<PiTrashSimpleFill />}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
EntityListActionBarDeleteButton.displayName = 'EntityListActionBarDeleteButton';
|
||||
@@ -1,70 +0,0 @@
|
||||
import { Box, Flex, Popover, PopoverBody, PopoverContent, PopoverTrigger, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import RgbColorPicker from 'common/components/RgbColorPicker';
|
||||
import { rgbColorToString } from 'common/util/colorCodeTransformers';
|
||||
import { MaskFillStyle } from 'features/controlLayers/components/common/MaskFillStyle';
|
||||
import { entityFillColorChanged, entityFillStyleChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectSelectedEntityFill, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import { type FillStyle, isMaskEntityIdentifier, type RgbColor } from 'features/controlLayers/store/types';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const EntityListActionBarSelectedEntityFill = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const fill = useAppSelector(selectSelectedEntityFill);
|
||||
|
||||
const onChangeFillColor = useCallback(
|
||||
(color: RgbColor) => {
|
||||
if (!selectedEntityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isMaskEntityIdentifier(selectedEntityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
dispatch(entityFillColorChanged({ entityIdentifier: selectedEntityIdentifier, color }));
|
||||
},
|
||||
[dispatch, selectedEntityIdentifier]
|
||||
);
|
||||
const onChangeFillStyle = useCallback(
|
||||
(style: FillStyle) => {
|
||||
if (!selectedEntityIdentifier) {
|
||||
return;
|
||||
}
|
||||
if (!isMaskEntityIdentifier(selectedEntityIdentifier)) {
|
||||
return;
|
||||
}
|
||||
dispatch(entityFillStyleChanged({ entityIdentifier: selectedEntityIdentifier, style }));
|
||||
},
|
||||
[dispatch, selectedEntityIdentifier]
|
||||
);
|
||||
|
||||
if (!selectedEntityIdentifier || !fill) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<Flex role="button" aria-label={t('controlLayers.maskFill')} tabIndex={-1} w={8} h={8}>
|
||||
<Tooltip label={t('controlLayers.maskFill')}>
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Box borderRadius="full" w={6} h={6} borderWidth={1} bg={rgbColorToString(fill.color)} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverBody minH={64}>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<RgbColorPicker color={fill.color} onChange={onChangeFillColor} withNumberInput />
|
||||
<MaskFillStyle style={fill.style} onChange={onChangeFillStyle} />
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
});
|
||||
|
||||
EntityListActionBarSelectedEntityFill.displayName = 'EntityListActionBarSelectedEntityFill';
|
||||
@@ -0,0 +1,26 @@
|
||||
import { Button, ButtonGroup } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { sessionModeChanged } from 'features/controlLayers/store/canvasSessionSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasModeSwitcher = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const mode = useAppSelector((s) => s.canvasSession.mode);
|
||||
const onClickGenerate = useCallback(() => dispatch(sessionModeChanged({ mode: 'generate' })), [dispatch]);
|
||||
const onClickCompose = useCallback(() => dispatch(sessionModeChanged({ mode: 'compose' })), [dispatch]);
|
||||
|
||||
return (
|
||||
<ButtonGroup variant="outline">
|
||||
<Button onClick={onClickGenerate} colorScheme={mode === 'generate' ? 'invokeBlue' : 'base'}>
|
||||
{t('controlLayers.generateMode')}
|
||||
</Button>
|
||||
<Button onClick={onClickCompose} colorScheme={mode === 'compose' ? 'invokeBlue' : 'base'}>
|
||||
{t('controlLayers.composeMode')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasModeSwitcher.displayName = 'CanvasModeSwitcher';
|
||||
@@ -1,23 +1,32 @@
|
||||
import { Divider, Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, ContextMenu, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasAddEntityButtons } from 'features/controlLayers/components/CanvasAddEntityButtons';
|
||||
import { CanvasEntityList } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityList';
|
||||
import { EntityListActionBar } from 'features/controlLayers/components/CanvasEntityList/EntityListActionBar';
|
||||
import { CanvasEntityListMenuItems } from 'features/controlLayers/components/CanvasEntityList/CanvasEntityListMenuItems';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectHasEntities } from 'features/controlLayers/store/selectors';
|
||||
import { memo } from 'react';
|
||||
import { selectEntityCount } from 'features/controlLayers/store/selectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
export const CanvasPanelContent = memo(() => {
|
||||
const hasEntities = useAppSelector(selectHasEntities);
|
||||
|
||||
const hasEntities = useAppSelector((s) => selectEntityCount(s) > 0);
|
||||
const renderMenu = useCallback(
|
||||
() => (
|
||||
<MenuList>
|
||||
<CanvasEntityListMenuItems />
|
||||
</MenuList>
|
||||
),
|
||||
[]
|
||||
);
|
||||
return (
|
||||
<CanvasManagerProviderGate>
|
||||
<Flex flexDir="column" gap={2} w="full" h="full">
|
||||
<EntityListActionBar />
|
||||
<Divider py={0} />
|
||||
{!hasEntities && <CanvasAddEntityButtons />}
|
||||
{hasEntities && <CanvasEntityList />}
|
||||
</Flex>
|
||||
<ContextMenu<HTMLDivElement> renderMenu={renderMenu}>
|
||||
{(ref) => (
|
||||
<Box ref={ref} w="full" h="full">
|
||||
{!hasEntities && <CanvasAddEntityButtons />}
|
||||
{hasEntities && <CanvasEntityList />}
|
||||
</Box>
|
||||
)}
|
||||
</ContextMenu>
|
||||
</CanvasManagerProviderGate>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { $alt, IconButton } from '@invoke-ai/ui-library';
|
||||
import { $shift, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { INTERACTION_SCOPES } from 'common/hooks/interactionScopes';
|
||||
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
@@ -7,7 +7,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
export const CanvasToolbarResetViewButton = memo(() => {
|
||||
export const CanvasResetViewButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const canvasManager = useStore($canvasManager);
|
||||
const isCanvasActive = useStore(INTERACTION_SCOPES.canvas.$isActive);
|
||||
@@ -27,7 +27,7 @@ export const CanvasToolbarResetViewButton = memo(() => {
|
||||
}, [canvasManager]);
|
||||
|
||||
const onReset = useCallback(() => {
|
||||
if ($alt.get()) {
|
||||
if ($shift.get()) {
|
||||
resetView();
|
||||
} else {
|
||||
resetZoom();
|
||||
@@ -35,7 +35,7 @@ export const CanvasToolbarResetViewButton = memo(() => {
|
||||
}, [resetView, resetZoom]);
|
||||
|
||||
useHotkeys('r', resetView, { enabled: isCanvasActive }, [isCanvasActive]);
|
||||
useHotkeys('alt+r', resetZoom, { enabled: isCanvasActive }, [isCanvasActive]);
|
||||
useHotkeys('shift+r', resetZoom, { enabled: isCanvasActive }, [isCanvasActive]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -48,4 +48,4 @@ export const CanvasToolbarResetViewButton = memo(() => {
|
||||
);
|
||||
});
|
||||
|
||||
CanvasToolbarResetViewButton.displayName = 'CanvasToolbarResetViewButton';
|
||||
CanvasResetViewButton.displayName = 'CanvasResetViewButton';
|
||||
@@ -15,8 +15,9 @@ import {
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { MAX_CANVAS_SCALE, MIN_CANVAS_SCALE } from 'features/controlLayers/konva/constants';
|
||||
import { snapToNearest } from 'features/controlLayers/konva/util';
|
||||
import { round } from 'lodash-es';
|
||||
import { clamp, round } from 'lodash-es';
|
||||
import { computed } from 'nanostores';
|
||||
import type { KeyboardEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
@@ -31,7 +32,7 @@ function formatPct(v: number | string) {
|
||||
return `${round(Number(v), 2).toLocaleString()}%`;
|
||||
}
|
||||
|
||||
function mapSliderValueToRawValue(value: number) {
|
||||
function mapSliderValueToScale(value: number) {
|
||||
if (value <= 40) {
|
||||
// 0 to 40 -> 10% to 100%
|
||||
return 10 + (90 * value) / 40;
|
||||
@@ -44,58 +45,64 @@ function mapSliderValueToRawValue(value: number) {
|
||||
}
|
||||
}
|
||||
|
||||
function mapRawValueToSliderValue(value: number) {
|
||||
if (value <= 100) {
|
||||
return ((value - 10) * 40) / 90;
|
||||
} else if (value <= 500) {
|
||||
return 40 + ((value - 100) * 30) / 400;
|
||||
function mapScaleToSliderValue(scale: number) {
|
||||
if (scale <= 100) {
|
||||
return ((scale - 10) * 40) / 90;
|
||||
} else if (scale <= 500) {
|
||||
return 40 + ((scale - 100) * 30) / 400;
|
||||
} else {
|
||||
return 70 + ((value - 500) * 30) / 1500;
|
||||
return 70 + ((scale - 500) * 30) / 1500;
|
||||
}
|
||||
}
|
||||
|
||||
function formatSliderValue(value: number) {
|
||||
return String(mapSliderValueToRawValue(value));
|
||||
return String(mapSliderValueToScale(value));
|
||||
}
|
||||
|
||||
const marks = [
|
||||
mapRawValueToSliderValue(10),
|
||||
mapRawValueToSliderValue(50),
|
||||
mapRawValueToSliderValue(100),
|
||||
mapRawValueToSliderValue(500),
|
||||
mapRawValueToSliderValue(2000),
|
||||
mapScaleToSliderValue(10),
|
||||
mapScaleToSliderValue(50),
|
||||
mapScaleToSliderValue(100),
|
||||
mapScaleToSliderValue(500),
|
||||
mapScaleToSliderValue(2000),
|
||||
];
|
||||
|
||||
const sliderDefaultValue = mapRawValueToSliderValue(100);
|
||||
const sliderDefaultValue = mapScaleToSliderValue(100);
|
||||
|
||||
const snapCandidates = marks.slice(1, marks.length - 1);
|
||||
|
||||
export const CanvasToolbarScale = memo(() => {
|
||||
export const CanvasScale = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const canvasManager = useCanvasManager();
|
||||
const scale = useStore(computed(canvasManager.stage.$stageAttrs, (attrs) => attrs.scale));
|
||||
const scale = useStore(computed(canvasManager.stateApi.$stageAttrs, (attrs) => attrs.scale));
|
||||
const [localScale, setLocalScale] = useState(scale * 100);
|
||||
|
||||
const onChangeSlider = useCallback(
|
||||
(scale: number) => {
|
||||
if (!canvasManager) {
|
||||
return;
|
||||
}
|
||||
let snappedScale = scale;
|
||||
// Do not snap if shift key is held
|
||||
if (!$shift.get()) {
|
||||
snappedScale = snapToNearest(scale, snapCandidates, 2);
|
||||
}
|
||||
const mappedScale = mapSliderValueToRawValue(snappedScale);
|
||||
const mappedScale = mapSliderValueToScale(snappedScale);
|
||||
canvasManager.stage.setScale(mappedScale / 100);
|
||||
},
|
||||
[canvasManager]
|
||||
);
|
||||
|
||||
const onBlur = useCallback(() => {
|
||||
if (!canvasManager) {
|
||||
return;
|
||||
}
|
||||
if (isNaN(Number(localScale))) {
|
||||
canvasManager.stage.setScale(1);
|
||||
setLocalScale(100);
|
||||
return;
|
||||
}
|
||||
canvasManager.stage.setScale(localScale / 100);
|
||||
canvasManager.stage.setScale(clamp(localScale / 100, MIN_CANVAS_SCALE, MAX_CANVAS_SCALE));
|
||||
}, [canvasManager, localScale]);
|
||||
|
||||
const onChangeNumberInput = useCallback((valueAsString: string, valueAsNumber: number) => {
|
||||
@@ -123,8 +130,8 @@ export const CanvasToolbarScale = memo(() => {
|
||||
<NumberInput
|
||||
display="flex"
|
||||
alignItems="center"
|
||||
min={canvasManager.stage.config.MIN_SCALE * 100}
|
||||
max={canvasManager.stage.config.MAX_SCALE * 100}
|
||||
min={MIN_CANVAS_SCALE * 100}
|
||||
max={MAX_CANVAS_SCALE * 100}
|
||||
value={localScale}
|
||||
onChange={onChangeNumberInput}
|
||||
onBlur={onBlur}
|
||||
@@ -155,7 +162,7 @@ export const CanvasToolbarScale = memo(() => {
|
||||
<CompositeSlider
|
||||
min={0}
|
||||
max={100}
|
||||
value={mapRawValueToSliderValue(localScale)}
|
||||
value={mapScaleToSliderValue(localScale)}
|
||||
onChange={onChangeSlider}
|
||||
defaultValue={sliderDefaultValue}
|
||||
marks={marks}
|
||||
@@ -168,4 +175,4 @@ export const CanvasToolbarScale = memo(() => {
|
||||
);
|
||||
});
|
||||
|
||||
CanvasToolbarScale.displayName = 'CanvasToolbarScale';
|
||||
CanvasScale.displayName = 'CanvasScale';
|
||||
@@ -1,65 +0,0 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { IconSwitch } from 'common/components/IconSwitch';
|
||||
import {
|
||||
selectCanvasSettingsSlice,
|
||||
settingsSendToCanvasChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImageBold, PiPaintBrushBold } from 'react-icons/pi';
|
||||
|
||||
const TooltipSendToGallery = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToGallery')}</Text>
|
||||
<Text fontWeight="normal">{t('controlLayers.sendToGalleryDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipSendToGallery.displayName = 'TooltipSendToGallery';
|
||||
|
||||
const TooltipSendToCanvas = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">{t('controlLayers.sendToCanvas')}</Text>
|
||||
<Text fontWeight="normal">{t('controlLayers.sendToCanvasDesc')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
TooltipSendToCanvas.displayName = 'TooltipSendToCanvas';
|
||||
|
||||
const selectSendToCanvas = createSelector(selectCanvasSettingsSlice, (canvasSettings) => canvasSettings.sendToCanvas);
|
||||
|
||||
export const CanvasSendToToggle = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const sendToCanvas = useAppSelector(selectSendToCanvas);
|
||||
|
||||
const onChange = useCallback(
|
||||
(isChecked: boolean) => {
|
||||
dispatch(settingsSendToCanvasChanged(isChecked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IconSwitch
|
||||
isChecked={sendToCanvas}
|
||||
onChange={onChange}
|
||||
iconUnchecked={<PiImageBold />}
|
||||
tooltipUnchecked={<TooltipSendToGallery />}
|
||||
iconChecked={<PiPaintBrushBold />}
|
||||
tooltipChecked={<TooltipSendToCanvas />}
|
||||
ariaLabel="Toggle canvas mode"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSendToToggle.displayName = 'CanvasSendToToggle';
|
||||
@@ -1,13 +1,13 @@
|
||||
import { Spacer } from '@invoke-ai/ui-library';
|
||||
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
|
||||
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
|
||||
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
|
||||
import { CanvasEntityHeaderCommonActions } from 'features/controlLayers/components/common/CanvasEntityHeaderCommonActions';
|
||||
import { CanvasEntityPreviewImage } from 'features/controlLayers/components/common/CanvasEntityPreviewImage';
|
||||
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
|
||||
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
|
||||
import { ControlLayerBadges } from 'features/controlLayers/components/ControlLayer/ControlLayerBadges';
|
||||
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
|
||||
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { EntityLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
|
||||
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
|
||||
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { memo, useMemo } from 'react';
|
||||
@@ -21,20 +21,20 @@ export const ControlLayer = memo(({ id }: Props) => {
|
||||
|
||||
return (
|
||||
<EntityIdentifierContext.Provider value={entityIdentifier}>
|
||||
<ControlLayerAdapterGate>
|
||||
<EntityLayerAdapterGate>
|
||||
<CanvasEntityContainer>
|
||||
<CanvasEntityHeader>
|
||||
<CanvasEntityPreviewImage />
|
||||
<CanvasEntityEditableTitle />
|
||||
<Spacer />
|
||||
<ControlLayerBadges />
|
||||
<CanvasEntityHeaderCommonActions />
|
||||
<CanvasEntityEnabledToggle />
|
||||
</CanvasEntityHeader>
|
||||
<CanvasEntitySettingsWrapper>
|
||||
<ControlLayerControlAdapter />
|
||||
</CanvasEntitySettingsWrapper>
|
||||
</CanvasEntityContainer>
|
||||
</ControlLayerAdapterGate>
|
||||
</EntityLayerAdapterGate>
|
||||
</EntityIdentifierContext.Provider>
|
||||
);
|
||||
});
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user