mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 22:48:02 -05:00
Compare commits
548 Commits
ryan/sd35
...
ryan/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2144d21f80 | ||
|
|
958efa19d7 | ||
|
|
11af57def3 | ||
|
|
8b70a5b9bd | ||
|
|
5d9fdcd78d | ||
|
|
c7b84cf012 | ||
|
|
8e409e3436 | ||
|
|
987393853c | ||
|
|
91c5af1b95 | ||
|
|
5c67dd507a | ||
|
|
2ff928ec17 | ||
|
|
4327bbe77e | ||
|
|
ad1c0d37ef | ||
|
|
9708d87946 | ||
|
|
3ad44f7850 | ||
|
|
9a482981b2 | ||
|
|
6b02362b12 | ||
|
|
8fec4ec91c | ||
|
|
693e421970 | ||
|
|
dc14104bc8 | ||
|
|
f286a1d1f3 | ||
|
|
9dc86b2b71 | ||
|
|
2cab689b79 | ||
|
|
f8c7adddd0 | ||
|
|
17da1d92e9 | ||
|
|
1cc57a4854 | ||
|
|
3993fae331 | ||
|
|
1446526d55 | ||
|
|
62c024e725 | ||
|
|
1e92bb4e94 | ||
|
|
db6398fdf6 | ||
|
|
ebd73a2ac2 | ||
|
|
8ee95cab00 | ||
|
|
d1184201a8 | ||
|
|
5887891654 | ||
|
|
765ca4e004 | ||
|
|
159b00a490 | ||
|
|
3fbf6f2d2a | ||
|
|
931fca7cd1 | ||
|
|
db84a3a5d4 | ||
|
|
ca8313e805 | ||
|
|
df849035ee | ||
|
|
8d97fe69ca | ||
|
|
9044e53a9b | ||
|
|
6012b0f912 | ||
|
|
bb0ed5dc8a | ||
|
|
021552fd81 | ||
|
|
be73dbba92 | ||
|
|
db9c0cad7c | ||
|
|
54b7f9a063 | ||
|
|
7d488a5352 | ||
|
|
4d7667f63d | ||
|
|
08704ee8ec | ||
|
|
5910892c33 | ||
|
|
46a09d9e90 | ||
|
|
df0c7d73f3 | ||
|
|
3905c97e32 | ||
|
|
0be796a808 | ||
|
|
7dd33b0f39 | ||
|
|
484aaf1595 | ||
|
|
c276b60af9 | ||
|
|
5d8dd6e26e | ||
|
|
5bca68d873 | ||
|
|
64364e7911 | ||
|
|
6565cea039 | ||
|
|
3ebd8d6c07 | ||
|
|
e970185161 | ||
|
|
fa5653cdf7 | ||
|
|
9a7b000995 | ||
|
|
3a27242838 | ||
|
|
8cfb032051 | ||
|
|
06a9d4e2b2 | ||
|
|
ed46acee79 | ||
|
|
b54463d294 | ||
|
|
faee79dc95 | ||
|
|
965cd76e33 | ||
|
|
e5e8cbf34c | ||
|
|
3412a52594 | ||
|
|
e01f66b026 | ||
|
|
53abdde242 | ||
|
|
94c088300f | ||
|
|
3741a6f5e0 | ||
|
|
059336258f | ||
|
|
2c23b8414c | ||
|
|
271cc52c80 | ||
|
|
20356c0746 | ||
|
|
e44458609f | ||
|
|
69d86a7696 | ||
|
|
56db1a9292 | ||
|
|
cf50e5eeee | ||
|
|
c9c07968d2 | ||
|
|
97d0757176 | ||
|
|
0f51b677a9 | ||
|
|
56ca94c3a9 | ||
|
|
28d169f859 | ||
|
|
92f71d99ee | ||
|
|
0764c02b1d | ||
|
|
081c7569fe | ||
|
|
20f6532ee8 | ||
|
|
b9e8910478 | ||
|
|
ded8391e3c | ||
|
|
e9dd2c396a | ||
|
|
0d86de0cb5 | ||
|
|
bad1149504 | ||
|
|
fda7aaa7ca | ||
|
|
85c616fa34 | ||
|
|
549f4e9794 | ||
|
|
ef8ededd2f | ||
|
|
1948ffe106 | ||
|
|
c70f4404c4 | ||
|
|
b157ae928c | ||
|
|
7a0871992d | ||
|
|
b38e2e14f4 | ||
|
|
7c0e70ec84 | ||
|
|
a89ae9d2bf | ||
|
|
ad1fcb3f07 | ||
|
|
87d74b910b | ||
|
|
7ad1c297a4 | ||
|
|
fbc629faa6 | ||
|
|
7baa6b3c09 | ||
|
|
53d482bade | ||
|
|
5aca04b51b | ||
|
|
ea8787c8ff | ||
|
|
cead2c4445 | ||
|
|
f76ac1808c | ||
|
|
f01210861b | ||
|
|
f757f23ef0 | ||
|
|
872a6ef209 | ||
|
|
4267e5ffc4 | ||
|
|
a69c5ff9ef | ||
|
|
3ebd8d7d1b | ||
|
|
1fd80d54a4 | ||
|
|
991f63e455 | ||
|
|
6a1efd3527 | ||
|
|
0eadc0dd9e | ||
|
|
481423d678 | ||
|
|
89ede0aef3 | ||
|
|
359bdee9c6 | ||
|
|
0e6fba3763 | ||
|
|
652502d7a6 | ||
|
|
91d981a49e | ||
|
|
24f61d21b2 | ||
|
|
eb9a4177c5 | ||
|
|
3c43351a5b | ||
|
|
b1359b6dff | ||
|
|
bddccf6d2f | ||
|
|
21ffaab2a2 | ||
|
|
1e969f938f | ||
|
|
9c6c86ee4f | ||
|
|
6b53a48b48 | ||
|
|
c813fa3fc0 | ||
|
|
a08e61184a | ||
|
|
a0d62a5f41 | ||
|
|
616c0f11e1 | ||
|
|
e1626a4e49 | ||
|
|
6ab891a319 | ||
|
|
492de41316 | ||
|
|
c064efc866 | ||
|
|
1a0885bfb1 | ||
|
|
e8b202d0a5 | ||
|
|
c6fc82f756 | ||
|
|
9a77e951d2 | ||
|
|
8bd4207a27 | ||
|
|
0bb601aaf7 | ||
|
|
2da25a0043 | ||
|
|
51d0931898 | ||
|
|
357b68d1ba | ||
|
|
d9ddb6c32e | ||
|
|
ad02a99a83 | ||
|
|
b707dafc7b | ||
|
|
02906c8f5d | ||
|
|
8538e508f1 | ||
|
|
8c333ffd14 | ||
|
|
72ace5fdff | ||
|
|
9b7583fc84 | ||
|
|
989eee338e | ||
|
|
acc3d7b91b | ||
|
|
49de868658 | ||
|
|
b1702c7d90 | ||
|
|
e49e19ea13 | ||
|
|
c9f91f391e | ||
|
|
4cb6b2b701 | ||
|
|
7d132ea148 | ||
|
|
1088accd91 | ||
|
|
8d237d8f8b | ||
|
|
0c86a3232d | ||
|
|
dbfb0359cb | ||
|
|
b4c2aa596b | ||
|
|
87e89b7995 | ||
|
|
9b089430e2 | ||
|
|
f2b0025958 | ||
|
|
4b390906bc | ||
|
|
c5b8efe03b | ||
|
|
4d08d00ad8 | ||
|
|
9b0130262b | ||
|
|
878093f64e | ||
|
|
d5ff7ef250 | ||
|
|
f36583f866 | ||
|
|
829bc1bc7d | ||
|
|
17c7b57145 | ||
|
|
6a12189542 | ||
|
|
96a31a5563 | ||
|
|
067747eca9 | ||
|
|
c7878fddc6 | ||
|
|
54c51e0a06 | ||
|
|
1640ea0298 | ||
|
|
0c32ae9775 | ||
|
|
fdb8ca5165 | ||
|
|
571faf6d7c | ||
|
|
bdbdb22b74 | ||
|
|
9bbb5644af | ||
|
|
e90ad19f22 | ||
|
|
0ba11e8f73 | ||
|
|
1cf7600f5b | ||
|
|
4f9d12b872 | ||
|
|
68c3b0649b | ||
|
|
8ef8bd4261 | ||
|
|
50897ba066 | ||
|
|
3510643870 | ||
|
|
ca9cb1c9ef | ||
|
|
b89caa02bd | ||
|
|
eaf4e08c44 | ||
|
|
fb19621361 | ||
|
|
9179619077 | ||
|
|
13cb5f0ba2 | ||
|
|
7e52fc1c17 | ||
|
|
7f60a4a282 | ||
|
|
3f880496f7 | ||
|
|
f05efd3270 | ||
|
|
79eb8172b6 | ||
|
|
7732b5d478 | ||
|
|
a2a1934b66 | ||
|
|
dff6570078 | ||
|
|
04e4fb63af | ||
|
|
83609d5008 | ||
|
|
2618ed0ae7 | ||
|
|
bb3cedddd5 | ||
|
|
5b3e1593ca | ||
|
|
2d08078a7d | ||
|
|
75acece1f1 | ||
|
|
a9db2ffefd | ||
|
|
cdd148b4d1 | ||
|
|
730fabe2de | ||
|
|
6c59790a7f | ||
|
|
0e6cb91863 | ||
|
|
a0fefcd43f | ||
|
|
c37251d6f7 | ||
|
|
2854210162 | ||
|
|
5545b980af | ||
|
|
0c9434c464 | ||
|
|
8771de917d | ||
|
|
122946ef4c | ||
|
|
2d974f670c | ||
|
|
75f0da9c35 | ||
|
|
5df3c00e28 | ||
|
|
b049880502 | ||
|
|
e5293fdd1a | ||
|
|
8883775762 | ||
|
|
cfadb313d2 | ||
|
|
b5cadd9a1a | ||
|
|
5361b6e014 | ||
|
|
ff346172af | ||
|
|
92f660018b | ||
|
|
1afc2cba4e | ||
|
|
ee8359242c | ||
|
|
f0c80a8d7a | ||
|
|
8da9e7c1f6 | ||
|
|
6d7a486e5b | ||
|
|
57122c6aa3 | ||
|
|
54abd8d4d1 | ||
|
|
06283cffed | ||
|
|
27fa0e1140 | ||
|
|
533d48abdb | ||
|
|
6845cae4c9 | ||
|
|
31c9acb1fa | ||
|
|
fb5e462300 | ||
|
|
2f3abc29b1 | ||
|
|
c5c071f285 | ||
|
|
93a3ed56e7 | ||
|
|
406fc58889 | ||
|
|
cf67d084fd | ||
|
|
d4a95af14f | ||
|
|
8c8e7102c2 | ||
|
|
b6b9ea9d70 | ||
|
|
63126950bc | ||
|
|
29d63d5dea | ||
|
|
a5f8c23dee | ||
|
|
7bb4ea57c6 | ||
|
|
75dc961bcb | ||
|
|
a9a1f6ef21 | ||
|
|
aa40161f26 | ||
|
|
6efa812874 | ||
|
|
8a683f5a3c | ||
|
|
f4b0b6a93d | ||
|
|
1337c33ad3 | ||
|
|
2f6b035138 | ||
|
|
4f9ae44472 | ||
|
|
c682330852 | ||
|
|
c064257759 | ||
|
|
8a4c629576 | ||
|
|
496b02a3bc | ||
|
|
7b5efc2203 | ||
|
|
a01d44f813 | ||
|
|
63fb3a15e9 | ||
|
|
4d0837541b | ||
|
|
999809b4c7 | ||
|
|
c452edfb9f | ||
|
|
ad2cdbd8a2 | ||
|
|
f15c24bfa7 | ||
|
|
d1f653f28c | ||
|
|
244465d3a6 | ||
|
|
c6236ab70c | ||
|
|
644d5cb411 | ||
|
|
bb0a630416 | ||
|
|
2148ae9287 | ||
|
|
42d242609c | ||
|
|
fd0a52392b | ||
|
|
e64415d59a | ||
|
|
1871e0bdbf | ||
|
|
3ae9a965c2 | ||
|
|
85932e35a7 | ||
|
|
41b07a56cc | ||
|
|
54064c0cb8 | ||
|
|
68284b37fa | ||
|
|
ae5bc6f5d6 | ||
|
|
6dc16c9f54 | ||
|
|
faa9ac4e15 | ||
|
|
d0460849b0 | ||
|
|
bed3c2dd77 | ||
|
|
916ddd17d7 | ||
|
|
accfa7407f | ||
|
|
908db31e48 | ||
|
|
b70f632b26 | ||
|
|
d07a6385ab | ||
|
|
68df612fa1 | ||
|
|
3b96c79461 | ||
|
|
89bda5b983 | ||
|
|
22bff1fb22 | ||
|
|
55ba6488d1 | ||
|
|
2d78859171 | ||
|
|
3a661bac34 | ||
|
|
bb8a02de18 | ||
|
|
78155344f6 | ||
|
|
391a24b0f6 | ||
|
|
e75903389f | ||
|
|
27567052f2 | ||
|
|
6f447f7169 | ||
|
|
8b370cc182 | ||
|
|
af583d2971 | ||
|
|
0ebe8fb1bd | ||
|
|
befb629f46 | ||
|
|
874d67cb37 | ||
|
|
19f7a1295a | ||
|
|
78bd605617 | ||
|
|
b87f4e59a5 | ||
|
|
1eca4f12c8 | ||
|
|
f1de11d6bf | ||
|
|
9361ed9d70 | ||
|
|
ebabf4f7a8 | ||
|
|
606f3321f5 | ||
|
|
3970aa30fb | ||
|
|
678436e07c | ||
|
|
c620581699 | ||
|
|
c331d42ce4 | ||
|
|
1ac9b502f1 | ||
|
|
3fa478a12f | ||
|
|
2d86298b7f | ||
|
|
009cdb714c | ||
|
|
9d3f5427b4 | ||
|
|
e4b17f019a | ||
|
|
586c00bc02 | ||
|
|
0f11fda65a | ||
|
|
3e75331ef7 | ||
|
|
be133408ac | ||
|
|
7e1e0d6928 | ||
|
|
cd3d8df5a8 | ||
|
|
24d3c22017 | ||
|
|
b0d37f4e51 | ||
|
|
3559124674 | ||
|
|
6c33e02141 | ||
|
|
8cf94d602f | ||
|
|
016a6f182f | ||
|
|
6fbc019142 | ||
|
|
26f95d6a97 | ||
|
|
40f7b0d171 | ||
|
|
4904700751 | ||
|
|
83538c4b2b | ||
|
|
eb7b559529 | ||
|
|
4945465cf0 | ||
|
|
7eed7282a9 | ||
|
|
47f0781822 | ||
|
|
88b8e3e3d5 | ||
|
|
47c3ab9214 | ||
|
|
d6d436b59c | ||
|
|
6ff7057967 | ||
|
|
e032ab1179 | ||
|
|
65bddfcd93 | ||
|
|
2d3ce418dd | ||
|
|
548d72f7b9 | ||
|
|
19837a0f29 | ||
|
|
483b65a1dc | ||
|
|
b85931c7ab | ||
|
|
9225f47338 | ||
|
|
bccac5e4a6 | ||
|
|
7cb07fdc04 | ||
|
|
b137450026 | ||
|
|
dc5090469a | ||
|
|
e0ae2ace89 | ||
|
|
269faae04b | ||
|
|
e282acd41c | ||
|
|
a266668348 | ||
|
|
3bb3e142fc | ||
|
|
6ac6d70a22 | ||
|
|
b0acf33ba5 | ||
|
|
b3eb64b64c | ||
|
|
95f8ab1a29 | ||
|
|
4e043384db | ||
|
|
0f5df8ba17 | ||
|
|
2826ab48a2 | ||
|
|
7ff1b635c8 | ||
|
|
dfb5e8b5e5 | ||
|
|
7259da799c | ||
|
|
965069fce1 | ||
|
|
90232806d9 | ||
|
|
81bc153399 | ||
|
|
c63e526f99 | ||
|
|
2b74263007 | ||
|
|
d3a82f7119 | ||
|
|
291c5a0341 | ||
|
|
bcb41399ca | ||
|
|
6f0f53849b | ||
|
|
4e7d63761a | ||
|
|
198c84105d | ||
|
|
2453b9f443 | ||
|
|
b091aca986 | ||
|
|
8f02ce54a0 | ||
|
|
f4b7c63002 | ||
|
|
a4629280b5 | ||
|
|
855fb007da | ||
|
|
d805b52c1f | ||
|
|
2ea55685bb | ||
|
|
bd6ff3deaa | ||
|
|
82dd53ec88 | ||
|
|
71d749541d | ||
|
|
48a57fc4b9 | ||
|
|
530e0910fc | ||
|
|
2fdf8fc0a2 | ||
|
|
91db9c9300 | ||
|
|
bc42205593 | ||
|
|
2e3cba6416 | ||
|
|
7852aacd11 | ||
|
|
6cccd67ecd | ||
|
|
a7a89c9de1 | ||
|
|
5ca8eed89e | ||
|
|
c885c3c9a6 | ||
|
|
d81c38c350 | ||
|
|
92d5b73215 | ||
|
|
097e92db6a | ||
|
|
84c6209a45 | ||
|
|
107e48808a | ||
|
|
47168b5505 | ||
|
|
58152ec981 | ||
|
|
c74afbf332 | ||
|
|
7cdda00a54 | ||
|
|
a74282bce6 | ||
|
|
107f048c7a | ||
|
|
a2486a5f06 | ||
|
|
07ab116efb | ||
|
|
1a13af3c7a | ||
|
|
f2966a2594 | ||
|
|
58bb97e3c6 | ||
|
|
34569a2410 | ||
|
|
a84aa5c049 | ||
|
|
acfa9c87ef | ||
|
|
f245d8e429 | ||
|
|
62cf0f54e0 | ||
|
|
5f015e76ba | ||
|
|
aebcec28e0 | ||
|
|
db1c5a94f7 | ||
|
|
56222a8493 | ||
|
|
b7510ce709 | ||
|
|
5739799e2e | ||
|
|
813cf87920 | ||
|
|
c95b151daf | ||
|
|
a0f823a3cf | ||
|
|
64e0f6d688 | ||
|
|
ddd5b1087c | ||
|
|
008be9b846 | ||
|
|
8e7cabdc04 | ||
|
|
a4c4237f99 | ||
|
|
bda3740dcd | ||
|
|
5b4633baa9 | ||
|
|
96351181cb | ||
|
|
957d591d99 | ||
|
|
75f605ba1a | ||
|
|
ab898a7180 | ||
|
|
c9a4516ab1 | ||
|
|
fe97c0d5eb | ||
|
|
6056764840 | ||
|
|
8747c0dbb0 | ||
|
|
c5cdd5f9c6 | ||
|
|
abc5d53159 | ||
|
|
2f76019a89 | ||
|
|
3f45beb1ed | ||
|
|
bc1126a85b | ||
|
|
380017041e | ||
|
|
ab7cdbb7e0 | ||
|
|
e5b78d0221 | ||
|
|
1acaa6c486 | ||
|
|
b0381076b7 | ||
|
|
ffff2d6dbb | ||
|
|
afa9f07649 | ||
|
|
addb5c49ea | ||
|
|
a112d2d55b | ||
|
|
619a271c8a | ||
|
|
909f2ee36d | ||
|
|
b4cf3d9d03 | ||
|
|
e6ab6e0293 | ||
|
|
66d9c7c631 | ||
|
|
fec45f3eb6 | ||
|
|
7211d1a6fc | ||
|
|
f3069754a9 | ||
|
|
4f43152aeb | ||
|
|
7125055d02 | ||
|
|
c91a9ce390 | ||
|
|
3e7b73da2c | ||
|
|
61ac50c00d | ||
|
|
c1201f0bce | ||
|
|
acdffac5ad | ||
|
|
e420300fa4 | ||
|
|
260a5a4f9a | ||
|
|
ed0c2006fe | ||
|
|
9ffd888c86 | ||
|
|
175a9dc28d | ||
|
|
5764e4f7f2 | ||
|
|
4275a494b9 | ||
|
|
a3deb8d30d | ||
|
|
aafdb0a37b | ||
|
|
56a815719a | ||
|
|
4db26bfa3a | ||
|
|
8d84ccb12b | ||
|
|
3321d14997 | ||
|
|
43cc4684e1 | ||
|
|
afa5a4b17c | ||
|
|
33c433fe59 | ||
|
|
9cd47fa857 | ||
|
|
32d9abe802 | ||
|
|
3947d4a165 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -19,3 +19,4 @@
|
||||
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
|
||||
- [ ] _Tests added / updated (if applicable)_
|
||||
- [ ] _Documentation added / updated (if applicable)_
|
||||
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
|
||||
|
||||
14
SECURITY.md
Normal file
14
SECURITY.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Only the latest version of Invoke will receive security updates.
|
||||
We do not currently maintain multiple versions of the application with updates.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
To report a vulnerability, contact the Invoke team directly at security@invoke.ai
|
||||
|
||||
At this time, we do not maintain a formal bug bounty program.
|
||||
|
||||
You can also share identified security issues with our team on huntr.com
|
||||
@@ -1364,7 +1364,6 @@ the in-memory loaded model:
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
### get_model_by_key(key, [submodel]) -> LoadedModel
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ If you're a new contributor to InvokeAI or Open Source Projects, this is the gui
|
||||
## New Contributor Checklist
|
||||
|
||||
- [x] Set up your local development environment & fork of InvokAI by following [the steps outlined here](../dev-environment.md)
|
||||
- [x] Set up your local tooling with [this guide](InvokeAI/contributing/LOCAL_DEVELOPMENT/#developing-invokeai-in-vscode). Feel free to skip this step if you already have tooling you're comfortable with.
|
||||
- [x] Set up your local tooling with [this guide](../LOCAL_DEVELOPMENT.md). Feel free to skip this step if you already have tooling you're comfortable with.
|
||||
- [x] Familiarize yourself with [Git](https://www.atlassian.com/git) & our project structure by reading through the [development documentation](development.md)
|
||||
- [x] Join the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord
|
||||
- [x] Choose an issue to work on! This can be achieved by asking in the #dev-chat channel, tackling a [good first issue](https://github.com/invoke-ai/InvokeAI/contribute) or finding an item on the [roadmap](https://github.com/orgs/invoke-ai/projects/7). If nothing in any of those places catches your eye, feel free to work on something of interest to you!
|
||||
|
||||
@@ -17,46 +17,49 @@ If you just want to use Invoke, you should use the [installer][installer link].
|
||||
## Setup
|
||||
|
||||
1. Run through the [requirements][requirements link].
|
||||
1. [Fork and clone][forking link] the [InvokeAI repo][repo link].
|
||||
1. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
1. Create a python virtual environment inside the directory you just created:
|
||||
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
|
||||
3. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
|
||||
4. Create a python virtual environment inside the directory you just created:
|
||||
|
||||
```sh
|
||||
python3 -m venv .venv --prompt InvokeAI-Dev
|
||||
```
|
||||
```sh
|
||||
python3 -m venv .venv --prompt InvokeAI-Dev
|
||||
```
|
||||
|
||||
1. Activate the venv (you'll need to do this every time you want to run the app):
|
||||
5. Activate the venv (you'll need to do this every time you want to run the app):
|
||||
|
||||
```sh
|
||||
source .venv/bin/activate
|
||||
```
|
||||
```sh
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
1. Install the repo as an [editable install][editable install link]:
|
||||
6. Install the repo as an [editable install][editable install link]:
|
||||
|
||||
```sh
|
||||
pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
```sh
|
||||
pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
Refer to the [manual installation][manual install link]] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
|
||||
Refer to the [manual installation][manual install link]] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
|
||||
|
||||
1. Install the frontend dev toolchain:
|
||||
7. Install the frontend dev toolchain:
|
||||
|
||||
- [`nodejs`](https://nodejs.org/) (recommend v20 LTS)
|
||||
- [`pnpm`](https://pnpm.io/installation#installing-a-specific-version) (must be v8 - not v9!)
|
||||
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
|
||||
|
||||
1. Do a production build of the frontend:
|
||||
8. Do a production build of the frontend:
|
||||
|
||||
```sh
|
||||
pnpm build
|
||||
```
|
||||
```sh
|
||||
cd PATH_TO_INVOKEAI_REPO/invokeai/frontend/web
|
||||
pnpm i
|
||||
pnpm build
|
||||
```
|
||||
|
||||
1. Start the application:
|
||||
9. Start the application:
|
||||
|
||||
```sh
|
||||
python scripts/invokeai-web.py
|
||||
```
|
||||
```sh
|
||||
cd PATH_TO_INVOKEAI_REPO
|
||||
python scripts/invokeai-web.py
|
||||
```
|
||||
|
||||
1. Access the UI at `localhost:9090`.
|
||||
10. Access the UI at `localhost:9090`.
|
||||
|
||||
## Updating the UI
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ This project is a combined effort of dedicated people from across the world. [C
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
|
||||
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/docs/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
|
||||
|
||||
By making a contribution to this project, you certify that:
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ checkpoint models.
|
||||
|
||||
To solve this, go to the Model Manager tab (the cube), select the
|
||||
checkpoint model that's giving you trouble, and press the "Convert"
|
||||
button in the upper right of your browser window. This will conver the
|
||||
button in the upper right of your browser window. This will convert the
|
||||
checkpoint into a diffusers model, after which loading should be
|
||||
faster and less memory-intensive.
|
||||
|
||||
|
||||
@@ -97,16 +97,16 @@ Prior to installing PyPatchMatch, you need to take the following steps:
|
||||
sudo pacman -S --needed base-devel
|
||||
```
|
||||
|
||||
2. Install `opencv` and `blas`:
|
||||
2. Install `opencv`, `blas`, and required dependencies:
|
||||
|
||||
```sh
|
||||
sudo pacman -S opencv blas
|
||||
sudo pacman -S opencv blas fmt glew vtk hdf5
|
||||
```
|
||||
|
||||
or for CUDA support
|
||||
|
||||
```sh
|
||||
sudo pacman -S opencv-cuda blas
|
||||
sudo pacman -S opencv-cuda blas fmt glew vtk hdf5
|
||||
```
|
||||
|
||||
3. Fix the naming of the `opencv` package configuration file:
|
||||
|
||||
@@ -99,7 +99,6 @@ their descriptions.
|
||||
| Scale Latents | Scales latents by a given factor. |
|
||||
| Segment Anything Processor | Applies segment anything processing to image |
|
||||
| Show Image | Displays a provided image, and passes it forward in the pipeline. |
|
||||
| Step Param Easing | Experimental per-step parameter easing for denoising steps |
|
||||
| String Primitive Collection | A collection of string primitive values |
|
||||
| String Primitive | A string primitive value |
|
||||
| Subtract Integers | Subtracts two numbers |
|
||||
|
||||
@@ -259,7 +259,7 @@ def select_gpu() -> GpuType:
|
||||
[
|
||||
f"Detected the [gold1]{OS}-{ARCH}[/] platform",
|
||||
"",
|
||||
"See [deep_sky_blue1]https://invoke-ai.github.io/InvokeAI/#system[/] to ensure your system meets the minimum requirements.",
|
||||
"See [deep_sky_blue1]https://invoke-ai.github.io/InvokeAI/installation/requirements/[/] to ensure your system meets the minimum requirements.",
|
||||
"",
|
||||
"[red3]🠶[/] [b]Your GPU drivers must be correctly installed before using InvokeAI![/] [red3]🠴[/]",
|
||||
]
|
||||
|
||||
@@ -68,7 +68,7 @@ do_line_input() {
|
||||
printf "2: Open the developer console\n"
|
||||
printf "3: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
|
||||
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest\n\n"
|
||||
read -p "Please enter 1-4, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
do_choice $choice
|
||||
|
||||
@@ -40,6 +40,8 @@ class AppVersion(BaseModel):
|
||||
|
||||
version: str = Field(description="App version")
|
||||
|
||||
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
|
||||
|
||||
|
||||
class AppDependencyVersions(BaseModel):
|
||||
"""App depencency Versions Response"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
@@ -10,6 +11,7 @@ from enum import Enum
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import huggingface_hub
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.routing import APIRouter
|
||||
@@ -27,6 +29,7 @@ from invokeai.app.services.model_records import (
|
||||
ModelRecordChanges,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@@ -34,7 +37,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
@@ -808,7 +811,11 @@ def get_is_installed(
|
||||
for model in installed_models:
|
||||
if model.source == starter_model.source:
|
||||
return True
|
||||
if model.name == starter_model.name and model.base == starter_model.base and model.type == starter_model.type:
|
||||
if (
|
||||
(model.name == starter_model.name or model.name in starter_model.previous_names)
|
||||
and model.base == starter_model.base
|
||||
and model.type == starter_model.type
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -919,3 +926,51 @@ async def get_stats() -> Optional[CacheStats]:
|
||||
"""Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""
|
||||
|
||||
return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
|
||||
|
||||
|
||||
class HFTokenStatus(str, Enum):
|
||||
VALID = "valid"
|
||||
INVALID = "invalid"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HFTokenHelper:
|
||||
@classmethod
|
||||
def get_status(cls) -> HFTokenStatus:
|
||||
try:
|
||||
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
|
||||
# Valid token!
|
||||
return HFTokenStatus.VALID
|
||||
# No token set
|
||||
return HFTokenStatus.INVALID
|
||||
except Exception:
|
||||
return HFTokenStatus.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def set_token(cls, token: str) -> HFTokenStatus:
|
||||
with SuppressOutput(), contextlib.suppress(Exception):
|
||||
huggingface_hub.login(token=token, add_to_git_credential=False)
|
||||
return cls.get_status()
|
||||
|
||||
|
||||
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
|
||||
async def get_hf_login_status() -> HFTokenStatus:
|
||||
token_status = HFTokenHelper.get_status()
|
||||
|
||||
if token_status is HFTokenStatus.UNKNOWN:
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
||||
|
||||
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
|
||||
async def do_hf_login(
|
||||
token: str = Body(description="Hugging Face token to use for login", embed=True),
|
||||
) -> HFTokenStatus:
|
||||
HFTokenHelper.set_token(token)
|
||||
token_status = HFTokenHelper.get_status()
|
||||
|
||||
if token_status is HFTokenStatus.UNKNOWN:
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
||||
@@ -110,7 +110,7 @@ async def cancel_by_batch_ids(
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_destination",
|
||||
operation_id="cancel_by_destination",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
responses={200: {"model": CancelByDestinationResult}},
|
||||
)
|
||||
async def cancel_by_destination(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@@ -62,6 +63,7 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
|
||||
- `Deprecated`: The invocation is deprecated and may be removed in a future version.
|
||||
- `Internal`: The invocation is not intended for use by end-users. It may be changed or removed at any time, but is exposed for users to play with.
|
||||
- `Special`: The invocation is a special case and does not fit into any of the other classifications.
|
||||
"""
|
||||
|
||||
Stable = "stable"
|
||||
@@ -69,6 +71,7 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Prototype = "prototype"
|
||||
Deprecated = "deprecated"
|
||||
Internal = "internal"
|
||||
Special = "special"
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
@@ -192,12 +195,19 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def invalidate_typeadapter(cls) -> None:
|
||||
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
|
||||
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
|
||||
the updated allowlist and denylist."""
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
@@ -479,6 +489,26 @@ def invocation(
|
||||
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
# Validate the `invoke()` method is implemented
|
||||
if "invoke" in cls.__abstractmethods__:
|
||||
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
|
||||
|
||||
# And validate that `invoke()` returns a subclass of `BaseInvocationOutput
|
||||
invoke_return_annotation = signature(cls.invoke).return_annotation
|
||||
|
||||
try:
|
||||
# TODO(psyche): If `invoke()` is not defined, `return_annotation` ends up as the string "BaseInvocationOutput"
|
||||
# instead of the class `BaseInvocationOutput`. This may be a pydantic bug: https://github.com/pydantic/pydantic/issues/7978
|
||||
if isinstance(invoke_return_annotation, str):
|
||||
invoke_return_annotation = getattr(sys.modules[cls.__module__], invoke_return_annotation)
|
||||
|
||||
assert invoke_return_annotation is not BaseInvocationOutput
|
||||
assert issubclass(invoke_return_annotation, BaseInvocationOutput)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invocation "{invocation_type}" must have a return annotation of a subclass of BaseInvocationOutput (got "{invoke_return_annotation}")'
|
||||
)
|
||||
|
||||
docstring = cls.__doc__
|
||||
cls = create_model(
|
||||
cls.__qualname__,
|
||||
|
||||
@@ -1,98 +1,120 @@
|
||||
from typing import Any, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, LatentsField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def slerp(
|
||||
t: Union[float, np.ndarray],
|
||||
v0: Union[torch.Tensor, np.ndarray],
|
||||
v1: Union[torch.Tensor, np.ndarray],
|
||||
device: torch.device,
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
|
||||
|
||||
@invocation(
|
||||
"lblend",
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
tags=["latents", "blend", "mask"],
|
||||
category="latents",
|
||||
version="1.0.3",
|
||||
version="1.1.0",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
"""Blend two latents using a given alpha. If a mask is provided, the second latents will be masked before blending.
|
||||
Latents must have same size. Masking functionality added by @dwringer."""
|
||||
|
||||
latents_a: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents_b: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
latents_a: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
latents_b: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
mask: Optional[ImageField] = InputField(default=None, description="Mask for blending in latents B")
|
||||
alpha: float = InputField(ge=0, default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
return mask_tensor
|
||||
|
||||
def replace_tensor_from_masked_tensor(
|
||||
self, tensor: torch.Tensor, other_tensor: torch.Tensor, mask_tensor: torch.Tensor
|
||||
):
|
||||
output = tensor.clone()
|
||||
mask_tensor = mask_tensor.expand(output.shape)
|
||||
if output.dtype != torch.float16:
|
||||
output = torch.add(output, mask_tensor * torch.sub(other_tensor, tensor))
|
||||
else:
|
||||
output = torch.add(output, mask_tensor.half() * torch.sub(other_tensor, tensor))
|
||||
return output
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||
if self.mask is None:
|
||||
mask_tensor = torch.zeros(latents_a.shape[-2:])
|
||||
else:
|
||||
mask_tensor = self.prep_mask_tensor(context.images.get_pil(self.mask.image_name))
|
||||
mask_tensor = tv_resize(mask_tensor, latents_a.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
|
||||
latents_b = self.replace_tensor_from_masked_tensor(latents_b, latents_a, mask_tensor)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
raise ValueError("Latents to blend must be the same size.")
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
def slerp(
|
||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
||||
return v2_torch
|
||||
else:
|
||||
assert isinstance(v2, np.ndarray)
|
||||
return v2
|
||||
|
||||
# blend
|
||||
bl = slerp(self.alpha, latents_a, latents_b)
|
||||
assert isinstance(bl, torch.Tensor)
|
||||
blended_latents: torch.Tensor = bl # for type checking convenience
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b, device)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
@@ -95,6 +95,7 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_manager,
|
||||
),
|
||||
):
|
||||
context.util.signal_progress("Building conditioning")
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
assert isinstance(tokenizer, CLIPTokenizer)
|
||||
compel = Compel(
|
||||
@@ -191,6 +192,7 @@ class SDXLPromptInvocationBase:
|
||||
ti_manager,
|
||||
),
|
||||
):
|
||||
context.util.signal_progress("Building conditioning")
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
assert isinstance(tokenizer, CLIPTokenizer)
|
||||
|
||||
|
||||
1563
invokeai/app/invocations/composition-nodes.py
Normal file
1563
invokeai/app/invocations/composition-nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -65,6 +65,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
|
||||
@@ -131,6 +131,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(
|
||||
vae_info, self.fp32, self.tiled, masked_image.clone()
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers.scheduling_tcd import TCDScheduler
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
@@ -510,6 +511,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
ext_manager: ExtensionsManager,
|
||||
bgr_mode: bool = False,
|
||||
) -> None:
|
||||
if t2i_adapters is None:
|
||||
return
|
||||
@@ -519,6 +521,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
t2i_adapters = [t2i_adapters]
|
||||
|
||||
for t2i_adapter_field in t2i_adapters:
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
if bgr_mode: # SDXL t2i trained on cv2's BGR outputs, but PIL won't convert straight to BGR
|
||||
r, g, b = image.split()
|
||||
image = Image.merge("RGB", (b, g, r))
|
||||
ext_manager.add_extension(
|
||||
T2IAdapterExt(
|
||||
node_context=context,
|
||||
@@ -616,13 +622,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
|
||||
max_unet_downscale = 8
|
||||
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
|
||||
max_unet_downscale = 4
|
||||
|
||||
# SDXL adapters are trained on cv2's BGR outputs
|
||||
r, g, b = image.split()
|
||||
image = Image.merge("RGB", (b, g, r))
|
||||
else:
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
|
||||
|
||||
@@ -630,29 +640,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
with t2i_adapter_loaded_model as t2i_adapter_model:
|
||||
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
||||
|
||||
# Resize the T2I-Adapter input image.
|
||||
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
|
||||
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
|
||||
|
||||
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||
# T2I-Adapter model.
|
||||
#
|
||||
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||
# of the same requirements (e.g. preserving binary masks during resize).
|
||||
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
_, _, latent_height, latent_width = latents_shape
|
||||
control_height_resize = latent_height * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latent_width * LATENT_SCALE_FACTOR
|
||||
t2i_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=t2i_input_width,
|
||||
height=t2i_input_height,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
|
||||
device=t2i_adapter_model.device,
|
||||
dtype=t2i_adapter_model.dtype,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
|
||||
# Resize the T2I-Adapter input image.
|
||||
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||
# We crop the image to this size so that the positions match the input image on non-standard resolutions
|
||||
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
|
||||
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
|
||||
if t2i_image.shape[2] > t2i_input_height or t2i_image.shape[3] > t2i_input_width:
|
||||
t2i_image = t2i_image[
|
||||
:, :, : min(t2i_image.shape[2], t2i_input_height), : min(t2i_image.shape[3], t2i_input_width)
|
||||
]
|
||||
|
||||
adapter_state = t2i_adapter_model(t2i_image)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
@@ -900,7 +920,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||
# ext_manager.add_extension(ext)
|
||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||
bgr_mode = self.unet.unet.base == BaseModelType.StableDiffusionXL
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager, bgr_mode)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
FluxMainModel = "FluxMainModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
T5EncoderModel = "T5EncoderModelField"
|
||||
CLIPEmbedModel = "CLIPEmbedModelField"
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
@@ -131,6 +134,7 @@ class FieldDescriptions:
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
t5_encoder = "T5 tokenizer and text encoder"
|
||||
clip_embed_model = "CLIP Embed loader"
|
||||
clip_g_model = "CLIP-G Embed loader"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
mmditx = "MMDiTX"
|
||||
@@ -246,6 +250,17 @@ class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
|
||||
@@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
||||
@@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
@@ -56,7 +58,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.2.0",
|
||||
version="3.2.2",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -81,15 +83,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_text_conditioning: FluxConditioningField | None = InputField(
|
||||
negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
@@ -138,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _load_text_conditioning(
|
||||
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
return t5_embeddings, clip_embeddings
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
|
||||
context, self.positive_text_conditioning.conditioning_name, inference_dtype
|
||||
)
|
||||
neg_t5_embeddings: torch.Tensor | None = None
|
||||
neg_clip_embeddings: torch.Tensor | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
|
||||
context, self.negative_text_conditioning.conditioning_name, inference_dtype
|
||||
)
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
@@ -182,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
b, _c, latent_h, latent_w = noise.shape
|
||||
packed_h = latent_h // 2
|
||||
packed_w = latent_w // 2
|
||||
|
||||
# Load the conditioning data.
|
||||
pos_text_conditionings = self._load_text_conditioning(
|
||||
context=context,
|
||||
cond_field=self.positive_text_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
neg_text_conditionings: list[FluxTextConditioning] | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
neg_text_conditionings = self._load_text_conditioning(
|
||||
context=context,
|
||||
cond_field=self.negative_text_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
|
||||
pos_text_conditionings, img_seq_len=packed_h * packed_w
|
||||
)
|
||||
neg_regional_prompting_extension = (
|
||||
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
|
||||
if neg_text_conditionings
|
||||
else None
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
image_seq_len=packed_h * packed_w,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
@@ -207,9 +216,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"to be poor. Consider using a FLUX dev model instead."
|
||||
)
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
if self.add_noise:
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
x = init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
@@ -224,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, latent_h, latent_w = x.shape
|
||||
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
|
||||
pos_txt_ids = torch.zeros(
|
||||
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
neg_txt_ids: torch.Tensor | None = None
|
||||
if neg_t5_embeddings is not None:
|
||||
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
|
||||
neg_txt_ids = torch.zeros(
|
||||
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
|
||||
# packed_w correctly.
|
||||
assert packed_h * packed_w == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
@@ -334,12 +335,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=pos_t5_embeddings,
|
||||
txt_ids=pos_txt_ids,
|
||||
vec=pos_clip_embeddings,
|
||||
neg_txt=neg_t5_embeddings,
|
||||
neg_txt_ids=neg_txt_ids,
|
||||
neg_vec=neg_clip_embeddings,
|
||||
pos_regional_prompting_extension=pos_regional_prompting_extension,
|
||||
neg_regional_prompting_extension=neg_regional_prompting_extension,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
@@ -353,6 +350,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _load_text_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
cond_field: FluxConditioningField | list[FluxConditioningField],
|
||||
packed_height: int,
|
||||
packed_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> list[FluxTextConditioning]:
|
||||
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
|
||||
# Normalize to a list of FluxConditioningFields.
|
||||
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
|
||||
|
||||
text_conditionings: list[FluxTextConditioning] = []
|
||||
for cond_field in cond_list:
|
||||
# Load the text embeddings.
|
||||
cond_data = context.conditioning.load(cond_field.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the mask, if provided.
|
||||
mask: Optional[torch.Tensor] = None
|
||||
if cond_field.mask is not None:
|
||||
mask = context.tensors.load(cond_field.mask.tensor_name)
|
||||
mask = mask.to(device=device)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, packed_height, packed_width, dtype, device
|
||||
)
|
||||
|
||||
text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
|
||||
|
||||
return text_conditionings
|
||||
|
||||
@classmethod
|
||||
def prep_cfg_scale(
|
||||
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
|
||||
|
||||
@@ -11,7 +11,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, SubModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Tuple
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@@ -22,7 +29,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
@@ -41,7 +48,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
@@ -54,7 +64,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
return FluxConditioningOutput(
|
||||
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
|
||||
)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
@@ -71,6 +83,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
@@ -111,6 +124,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
|
||||
@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
img = vae.decode(latents)
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
@@ -53,6 +54,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
context.util.signal_progress("Running VAE")
|
||||
image = self._vae_decode(vae_info=vae_info, latents=latents)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
@@ -44,9 +44,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
image_tensor = image_tensor.to(
|
||||
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
|
||||
)
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
return latents
|
||||
|
||||
@@ -60,6 +59,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
context.util.signal_progress("Running VAE")
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
|
||||
59
invokeai/app/invocations/image_panels.py
Normal file
59
invokeai/app/invocations/image_panels.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("image_panel_coordinate_output")
|
||||
class ImagePanelCoordinateOutput(BaseInvocationOutput):
|
||||
x_left: int = OutputField(description="The left x-coordinate of the panel.")
|
||||
y_top: int = OutputField(description="The top y-coordinate of the panel.")
|
||||
width: int = OutputField(description="The width of the panel.")
|
||||
height: int = OutputField(description="The height of the panel.")
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_panel_layout",
|
||||
title="Image Panel Layout",
|
||||
tags=["image", "panel", "layout"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class ImagePanelLayoutInvocation(BaseInvocation):
|
||||
"""Get the coordinates of a single panel in a grid. (If the full image shape cannot be divided evenly into panels,
|
||||
then the grid may not cover the entire image.)
|
||||
"""
|
||||
|
||||
width: int = InputField(description="The width of the entire grid.")
|
||||
height: int = InputField(description="The height of the entire grid.")
|
||||
num_cols: int = InputField(ge=1, default=1, description="The number of columns in the grid.")
|
||||
num_rows: int = InputField(ge=1, default=1, description="The number of rows in the grid.")
|
||||
panel_col_idx: int = InputField(ge=0, default=0, description="The column index of the panel to be processed.")
|
||||
panel_row_idx: int = InputField(ge=0, default=0, description="The row index of the panel to be processed.")
|
||||
|
||||
@field_validator("panel_col_idx")
|
||||
def validate_panel_col_idx(cls, v: int, info: ValidationInfo) -> int:
|
||||
if v < 0 or v >= info.data["num_cols"]:
|
||||
raise ValueError(f"panel_col_idx must be between 0 and {info.data['num_cols'] - 1}")
|
||||
return v
|
||||
|
||||
@field_validator("panel_row_idx")
|
||||
def validate_panel_row_idx(cls, v: int, info: ValidationInfo) -> int:
|
||||
if v < 0 or v >= info.data["num_rows"]:
|
||||
raise ValueError(f"panel_row_idx must be between 0 and {info.data['num_rows'] - 1}")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImagePanelCoordinateOutput:
|
||||
x_left = self.panel_col_idx * (self.width // self.num_cols)
|
||||
y_top = self.panel_row_idx * (self.height // self.num_rows)
|
||||
width = self.width // self.num_cols
|
||||
height = self.height // self.num_rows
|
||||
return ImagePanelCoordinateOutput(x_left=x_left, y_top=y_top, width=width, height=height)
|
||||
@@ -117,6 +117,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
context.util.signal_progress("Running VAE encoder")
|
||||
latents = self.vae_encode(
|
||||
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
|
||||
)
|
||||
|
||||
@@ -60,6 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
context.util.signal_progress("Running VAE decoder")
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
|
||||
@@ -165,6 +165,7 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
mask: TensorField = InputField(description="The mask tensor to apply.")
|
||||
image: ImageField = InputField(description="The image to apply the mask to.")
|
||||
invert: bool = InputField(default=False, description="Whether to invert the mask.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGBA")
|
||||
@@ -179,6 +180,9 @@ class ApplyMaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
mask = mask > 0.5
|
||||
mask_np = (mask.float() * 255).byte().cpu().numpy().astype(np.uint8)
|
||||
|
||||
if self.invert:
|
||||
mask_np = 255 - mask_np
|
||||
|
||||
# Apply the mask only to the alpha channel where the original alpha is non-zero. This preserves the original
|
||||
# image's transparency - else the transparent regions would end up as opaque black.
|
||||
|
||||
|
||||
@@ -147,6 +147,10 @@ GENERATION_MODES = Literal[
|
||||
"flux_img2img",
|
||||
"flux_inpaint",
|
||||
"flux_outpaint",
|
||||
"sd3_txt2img",
|
||||
"sd3_img2img",
|
||||
"sd3_inpaint",
|
||||
"sd3_outpaint",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,43 +1,4 @@
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from easing_functions import (
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseInOut,
|
||||
CubicEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseOut,
|
||||
LinearInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseInOut,
|
||||
QuadEaseOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseOut,
|
||||
SineEaseIn,
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
@@ -65,191 +26,3 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@invocation(
|
||||
"step_param_easing",
|
||||
title="Step Param Easing",
|
||||
tags=["step", "easing"],
|
||||
category="step",
|
||||
version="1.0.2",
|
||||
)
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
context.logger.debug("start_step: " + str(start_step))
|
||||
context.logger.debug("end_step: " + str(end_step))
|
||||
context.logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
context.logger.debug("num_presteps: " + str(num_presteps))
|
||||
context.logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||
context.logger.debug("prelist: " + str(prelist))
|
||||
context.logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
context.logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = []
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||
if log_diagnostics:
|
||||
context.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||
easing_function = easing_class(
|
||||
start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1,
|
||||
)
|
||||
base_easing_vals = []
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
context.logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(
|
||||
start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1,
|
||||
)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
@@ -4,7 +4,13 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoundingBoxField,
|
||||
@@ -18,6 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
SD3ConditioningField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
@@ -426,6 +433,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
|
||||
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("sd3_conditioning_output")
|
||||
class SD3ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single SD3 conditioning tensor"""
|
||||
|
||||
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
|
||||
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
@@ -521,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "internal"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
338
invokeai/app/invocations/sd3_denoise.py
Normal file
338
invokeai/app/invocations/sd3_denoise.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
SD3ConditioningField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_denoise",
|
||||
title="SD3 Denoise",
|
||||
tags=["image", "sd3"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a SD3 model."""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||
)
|
||||
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
|
||||
)
|
||||
positive_conditioning: SD3ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: SD3ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
- Loads the mask
|
||||
- Resizes if necessary
|
||||
- Casts to same device/dtype as latents
|
||||
|
||||
Args:
|
||||
context (InvocationContext): The invocation context, for loading the inpaint mask.
|
||||
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
|
||||
inpaint mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
|
||||
represent the regions to be preserved.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
|
||||
# 1.0 represents the regions to be preserved.
|
||||
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
|
||||
mask = 1.0 - mask
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
return mask
|
||||
|
||||
def _load_text_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
conditioning_name: str,
|
||||
joint_attention_dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
sd3_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
|
||||
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
|
||||
|
||||
t5_embeds = sd3_conditioning.t5_embeds
|
||||
if t5_embeds is None:
|
||||
t5_embeds = torch.zeros(
|
||||
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat(
|
||||
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
|
||||
)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def _get_noise(
|
||||
self,
|
||||
num_samples: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
|
||||
rand_device = "cpu"
|
||||
rand_dtype = torch.float16
|
||||
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
num_channels_latents,
|
||||
int(height) // LATENT_SCALE_FACTOR,
|
||||
int(width) // LATENT_SCALE_FACTOR,
|
||||
device=rand_device,
|
||||
dtype=rand_dtype,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
|
||||
"""Prepare the CFG scale list.
|
||||
|
||||
Args:
|
||||
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
|
||||
on the scheduler used (e.g. higher order schedulers).
|
||||
|
||||
Returns:
|
||||
list[float]: _description_
|
||||
"""
|
||||
if isinstance(self.cfg_scale, float):
|
||||
cfg_scale = [self.cfg_scale] * num_timesteps
|
||||
elif isinstance(self.cfg_scale, list):
|
||||
assert len(self.cfg_scale) == num_timesteps
|
||||
cfg_scale = self.cfg_scale
|
||||
else:
|
||||
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
|
||||
|
||||
return cfg_scale
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Load/process the conditioning data.
|
||||
# TODO(ryand): Make CFG optional.
|
||||
do_classifier_free_guidance = True
|
||||
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
|
||||
context=context,
|
||||
conditioning_name=self.positive_conditioning.conditioning_name,
|
||||
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
|
||||
context=context,
|
||||
conditioning_name=self.negative_conditioning.conditioning_name,
|
||||
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
# TODO(ryand): Support both sequential and batched CFG inference.
|
||||
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
|
||||
|
||||
# Prepare the timestep schedule.
|
||||
# We add an extra step to the end to account for the final timestep of 0.0.
|
||||
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
|
||||
total_steps = len(timesteps) - 1
|
||||
|
||||
# Prepare the CFG scale list.
|
||||
cfg_scale = self._prepare_cfg_scale(total_steps)
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=device, dtype=inference_dtype)
|
||||
|
||||
# Generate initial latent noise.
|
||||
num_channels_latents = transformer_info.model.config.in_channels
|
||||
assert isinstance(num_channels_latents, int)
|
||||
noise = self._get_noise(
|
||||
num_samples=1,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
# Prepare input latent image.
|
||||
if init_latents is not None:
|
||||
# Noise the init_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
latents = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
latents = noise
|
||||
|
||||
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
|
||||
# denoising steps.
|
||||
if len(timesteps) <= 1:
|
||||
return latents
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_mask = self._prep_inpaint_mask(context, latents)
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
inpaint_extension = InpaintExtension(
|
||||
init_latents=init_latents,
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
step_callback = self._build_step_callback(context)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=0,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(timesteps[0]),
|
||||
latents=latents,
|
||||
),
|
||||
)
|
||||
|
||||
with transformer_info.model_on_device() as (cached_weights, transformer):
|
||||
assert isinstance(transformer, SD3Transformer2DModel)
|
||||
|
||||
# 6. Denoising loop
|
||||
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
|
||||
# Expand the latents if we are doing CFG.
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
# Expand the timestep to match the latent model input.
|
||||
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
|
||||
timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=None,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Apply CFG.
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1.
|
||||
latents_dtype = latents.dtype
|
||||
latents = latents.to(dtype=torch.float32)
|
||||
latents = latents + (t_prev - t_curr) * noise_pred
|
||||
latents = latents.to(dtype=latents_dtype)
|
||||
|
||||
if inpaint_extension is not None:
|
||||
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step_idx + 1,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(t_curr),
|
||||
latents=latents,
|
||||
),
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
|
||||
|
||||
return step_callback
|
||||
65
invokeai/app/invocations/sd3_image_to_latents.py
Normal file
65
invokeai/app/invocations/sd3_image_to_latents.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_i2l",
|
||||
title="SD3 Image to Latents",
|
||||
tags=["image", "latents", "vae", "i2l", "sd3"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates latents from an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to encode")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
# TODO: Use seed to make sampling reproducible.
|
||||
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
74
invokeai/app/invocations/sd3_latents_to_image.py
Normal file
74
invokeai/app/invocations/sd3_latents_to_image.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_l2i",
|
||||
title="SD3 Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i", "sd3"],
|
||||
category="latents",
|
||||
version="1.3.0",
|
||||
)
|
||||
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL))
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
context.util.signal_progress("Running VAE")
|
||||
assert isinstance(vae, (AutoencoderKL))
|
||||
latents = latents.to(vae.device)
|
||||
|
||||
vae.disable_tiling()
|
||||
|
||||
tiling_context = nullcontext()
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
with torch.inference_mode(), tiling_context:
|
||||
# copied from diffusers pipeline
|
||||
latents = latents / vae.config.scaling_factor
|
||||
img = vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=img_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -8,14 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, SubModelType
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
class Sd3ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SD3 base model loader output."""
|
||||
|
||||
mmditx: TransformerField = OutputField(description=FieldDescriptions.mmditx, title="MMDiTX")
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
|
||||
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
@@ -33,68 +35,72 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
|
||||
class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a SD3 base model, outputting its submodels."""
|
||||
|
||||
# TODO(ryand): Create a UIType.Sd3MainModelField to use here.
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sd3_model,
|
||||
ui_type=UIType.MainModel,
|
||||
ui_type=UIType.SD3MainModel,
|
||||
input=Input.Direct,
|
||||
)
|
||||
|
||||
# TODO(ryand): Make the text encoders optional.
|
||||
# Note: The text encoders are optional for SD3. The model was trained with dropout, so any can be left out at
|
||||
# inference time. Typically, only the T5 encoder is omitted, since it is the largest by far.
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
|
||||
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
ui_type=UIType.T5EncoderModel,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
default=None,
|
||||
)
|
||||
|
||||
clip_l_embed_model: ModelIdentifierField = InputField(
|
||||
clip_l_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
ui_type=UIType.CLIPLEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP L Embed",
|
||||
title="CLIP L Encoder",
|
||||
default=None,
|
||||
)
|
||||
|
||||
clip_g_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
ui_type=UIType.CLIPEmbedModel,
|
||||
clip_g_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.clip_g_model,
|
||||
ui_type=UIType.CLIPGEmbedModel,
|
||||
input=Input.Direct,
|
||||
title="CLIP G Embed",
|
||||
title="CLIP G Encoder",
|
||||
default=None,
|
||||
)
|
||||
|
||||
# TODO(ryand): Create a UIType.Sd3VaModelField to use here.
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE"
|
||||
vae_model: Optional[ModelIdentifierField] = InputField(
|
||||
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
|
||||
for key in [
|
||||
self.model.key,
|
||||
self.t5_encoder_model.key,
|
||||
self.clip_l_embed_model.key,
|
||||
self.clip_g_embed_model.key,
|
||||
self.vae_model.key,
|
||||
]:
|
||||
if not context.models.exists(key):
|
||||
raise ValueError(f"Unknown model: {key}")
|
||||
|
||||
# TODO(ryand): Figure out the sub-model types for SD3.
|
||||
mmditx = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
tokenizer_l = self.clip_l_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder_l = self.clip_l_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer_g = self.clip_g_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder_g = self.clip_g_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
tokenizer_t5 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
|
||||
transformer_config = context.models.get_config(mmditx)
|
||||
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
vae = (
|
||||
self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
if self.vae_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
)
|
||||
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
clip_encoder_l = (
|
||||
self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
if self.clip_l_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
)
|
||||
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
clip_encoder_g = (
|
||||
self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
if self.clip_g_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
)
|
||||
tokenizer_t5 = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
)
|
||||
t5_encoder = (
|
||||
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
if self.t5_encoder_model
|
||||
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
)
|
||||
|
||||
return Sd3ModelLoaderOutput(
|
||||
mmditx=TransformerField(transformer=mmditx, loras=[]),
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
|
||||
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
|
||||
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import SD3ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
SD3_T5_MAX_SEQ_LEN = 256
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_text_encoder",
|
||||
title="SD3 Text Encoding",
|
||||
tags=["prompt", "conditioning", "sd3"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a SD3 image."""
|
||||
|
||||
clip_l: CLIPField = InputField(
|
||||
title="CLIP L",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
clip_g: CLIPField = InputField(
|
||||
title="CLIP G",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
|
||||
t5_encoder: T5EncoderField | None = InputField(
|
||||
title="T5Encoder",
|
||||
default=None,
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
|
||||
# Note: The text encoding model are run in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
|
||||
|
||||
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
|
||||
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
|
||||
|
||||
t5_embeddings: torch.Tensor | None = None
|
||||
if self.t5_encoder is not None:
|
||||
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SD3ConditioningInfo(
|
||||
clip_l_embeds=clip_l_embeddings,
|
||||
clip_l_pooled_embeds=clip_l_pooled_embeddings,
|
||||
clip_g_embeds=clip_g_embeddings,
|
||||
clip_g_pooled_embeds=clip_g_pooled_embeddings,
|
||||
t5_embeds=t5_embeddings,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return SD3ConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
|
||||
assert self.t5_encoder is not None
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
t5_text_encoder_info as t5_text_encoder,
|
||||
t5_tokenizer_info as t5_tokenizer,
|
||||
):
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
|
||||
text_inputs = t5_tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
assert isinstance(text_input_ids, torch.Tensor)
|
||||
assert isinstance(untruncated_ids, torch.Tensor)
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
|
||||
context.logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_seq_len} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(
|
||||
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
clip_tokenizer_info = context.models.load(clip_model.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(clip_model.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
ExitStack() as exit_stack,
|
||||
):
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
|
||||
clip_text_encoder_config = clip_text_encoder_info.config
|
||||
assert clip_text_encoder_config is not None
|
||||
|
||||
# Apply LoRA models to the CLIP encoder.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context, clip_model),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# There are currently no supported CLIP quantized models. Add support here if needed.
|
||||
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
|
||||
|
||||
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
|
||||
|
||||
text_inputs = clip_tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
assert isinstance(text_input_ids, torch.Tensor)
|
||||
assert isinstance(untruncated_ids, torch.Tensor)
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
|
||||
context.logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = clip_text_encoder(
|
||||
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
|
||||
)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def _clip_lora_iterator(
|
||||
self, context: InvocationContext, clip_model: CLIPField
|
||||
) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_model.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Literal
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
||||
from transformers.models.sam import SamModel
|
||||
from transformers.models.sam.processing_sam import SamProcessor
|
||||
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
default="all",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_point_lists_or_bounding_box(self):
|
||||
if self.point_lists is None and self.bounding_boxes is None:
|
||||
raise ValueError("Either point_lists or bounding_box must be provided.")
|
||||
elif self.point_lists is not None and self.bounding_boxes is not None:
|
||||
raise ValueError("Only one of point_lists or bounding_box can be provided.")
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
# The models expect a 3-channel RGB image.
|
||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
if self.point_lists is not None and self.bounding_boxes is not None:
|
||||
raise ValueError("Only one of point_lists or bounding_box can be provided.")
|
||||
|
||||
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
|
||||
not self.point_lists or len(self.point_lists) == 0
|
||||
):
|
||||
|
||||
@@ -20,7 +20,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
NodeExecutionStatsSummary,
|
||||
)
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
|
||||
# Size of 1GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
@@ -24,7 +24,7 @@ class ModelLoadServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load import (
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -30,7 +30,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
@@ -45,7 +45,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@@ -78,15 +78,14 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
@@ -109,5 +108,5 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
self._ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache)
|
||||
|
||||
@@ -16,7 +16,8 @@ from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBas
|
||||
from invokeai.app.services.model_load.model_load_default import ModelLoadService
|
||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from pydantic import (
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowWithoutID,
|
||||
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
|
||||
|
||||
# region Batch
|
||||
|
||||
BatchDataType = Union[
|
||||
StrictStr,
|
||||
float,
|
||||
int,
|
||||
]
|
||||
BatchDataType = Union[StrictStr, float, int, ImageField]
|
||||
|
||||
|
||||
class NodeFieldValue(BaseModel):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
@@ -159,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ImagesInterface(InvocationContextInterface):
|
||||
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
|
||||
super().__init__(services, data)
|
||||
self._util = util
|
||||
|
||||
def save(
|
||||
self,
|
||||
image: Image,
|
||||
@@ -185,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
|
||||
The saved image DTO.
|
||||
"""
|
||||
|
||||
self._util.signal_progress("Saving image")
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
@@ -221,7 +228,7 @@ class ImagesInterface(InvocationContextInterface):
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||
"""Gets an image as a PIL Image object.
|
||||
"""Gets an image as a PIL Image object. This method returns a copy of the image.
|
||||
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
@@ -233,11 +240,15 @@ class ImagesInterface(InvocationContextInterface):
|
||||
image = self._services.images.get_pil_image(image_name)
|
||||
if mode and mode != image.mode:
|
||||
try:
|
||||
# convert makes a copy!
|
||||
image = image.convert(mode)
|
||||
except ValueError:
|
||||
self._services.logger.warning(
|
||||
f"Could not convert image from {image.mode} to {mode}. Using original mode instead."
|
||||
)
|
||||
else:
|
||||
# copy the image to prevent the user from modifying the original
|
||||
image = image.copy()
|
||||
return image
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
@@ -290,15 +301,15 @@ class TensorsInterface(InvocationContextInterface):
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""Loads a tensor by name.
|
||||
"""Loads a tensor by name. This method returns a copy of the tensor.
|
||||
|
||||
Args:
|
||||
name: The name of the tensor to load.
|
||||
|
||||
Returns:
|
||||
The loaded tensor.
|
||||
The tensor.
|
||||
"""
|
||||
return self._services.tensors.load(name)
|
||||
return self._services.tensors.load(name).clone()
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
@@ -316,21 +327,25 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""Loads conditioning data by name.
|
||||
"""Loads conditioning data by name. This method returns a copy of the conditioning data.
|
||||
|
||||
Args:
|
||||
name: The name of the conditioning data to load.
|
||||
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
The conditioning data.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
return deepcopy(self._services.conditioning.load(name))
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
"""Common API for loading, downloading and managing models."""
|
||||
|
||||
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
|
||||
super().__init__(services, data)
|
||||
self._util = util
|
||||
|
||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
||||
"""Check if a model exists.
|
||||
|
||||
@@ -363,11 +378,15 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type)
|
||||
|
||||
message = f"Loading model {model.name}"
|
||||
if submodel_type:
|
||||
message += f" ({submodel_type.value})"
|
||||
self._util.signal_progress(message)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@@ -392,6 +411,10 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
message = f"Loading model {name}"
|
||||
if submodel_type:
|
||||
message += f" ({submodel_type.value})"
|
||||
self._util.signal_progress(message)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
@@ -462,6 +485,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
Path to the downloaded model
|
||||
"""
|
||||
self._util.signal_progress(f"Downloading model {source}")
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
|
||||
def load_local_model(
|
||||
@@ -484,6 +508,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
|
||||
self._util.signal_progress(f"Loading model {model_path.name}")
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
def load_remote_model(
|
||||
@@ -509,6 +535,8 @@ class ModelsInterface(InvocationContextInterface):
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
|
||||
self._util.signal_progress(f"Loading model {source}")
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
|
||||
@@ -702,12 +730,12 @@ def build_invocation_context(
|
||||
"""
|
||||
|
||||
logger = LoggerInterface(services=services, data=data)
|
||||
images = ImagesInterface(services=services, data=data)
|
||||
tensors = TensorsInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data)
|
||||
config = ConfigInterface(services=services, data=data)
|
||||
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
|
||||
conditioning = ConditioningInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data, util=util)
|
||||
images = ImagesInterface(services=services, data=data, util=util)
|
||||
boards = BoardsInterface(services=services, data=data)
|
||||
|
||||
ctx = InvocationContext(
|
||||
|
||||
@@ -0,0 +1,382 @@
|
||||
{
|
||||
"name": "SD3.5 Text to Image",
|
||||
"author": "InvokeAI",
|
||||
"description": "Sample text to image workflow for Stable Diffusion 3.5",
|
||||
"version": "1.0.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text2image, SD3.5, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"fieldName": "prompt"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "3.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"type": "sd3_model_loader",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"name": "model",
|
||||
"label": "",
|
||||
"value": {
|
||||
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
|
||||
"hash": "placeholder",
|
||||
"name": "stable-diffusion-3.5-medium",
|
||||
"base": "sd-3",
|
||||
"type": "main"
|
||||
}
|
||||
},
|
||||
"t5_encoder_model": {
|
||||
"name": "t5_encoder_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_l_model": {
|
||||
"name": "clip_l_model",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g_model": {
|
||||
"name": "clip_g_model",
|
||||
"label": ""
|
||||
},
|
||||
"vae_model": {
|
||||
"name": "vae_model",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": -55.58689609637031,
|
||||
"y": -111.53602444662268
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"type": "rand_int",
|
||||
"version": "1.0.1",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"name": "low",
|
||||
"label": "",
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"name": "high",
|
||||
"label": "",
|
||||
"value": 2147483647
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 470.45870147220353,
|
||||
"y": 350.3141781644303
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"type": "sd3_l2i",
|
||||
"version": "1.3.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"latents": {
|
||||
"name": "latents",
|
||||
"label": ""
|
||||
},
|
||||
"vae": {
|
||||
"name": "vae",
|
||||
"label": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 1192.3097009334897,
|
||||
"y": -366.0994675072209
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"type": "sd3_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"clip_l": {
|
||||
"name": "clip_l",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g": {
|
||||
"name": "clip_g",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 408.16054647924784,
|
||||
"y": 65.06415352118786
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"type": "sd3_text_encoder",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"clip_l": {
|
||||
"name": "clip_l",
|
||||
"label": ""
|
||||
},
|
||||
"clip_g": {
|
||||
"name": "clip_g",
|
||||
"label": ""
|
||||
},
|
||||
"t5_encoder": {
|
||||
"name": "t5_encoder",
|
||||
"label": ""
|
||||
},
|
||||
"prompt": {
|
||||
"name": "prompt",
|
||||
"label": "",
|
||||
"value": ""
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 378.9283412440941,
|
||||
"y": -302.65777497352553
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"type": "sd3_denoise",
|
||||
"version": "1.0.0",
|
||||
"label": "",
|
||||
"notes": "",
|
||||
"isOpen": true,
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"board": {
|
||||
"name": "board",
|
||||
"label": ""
|
||||
},
|
||||
"metadata": {
|
||||
"name": "metadata",
|
||||
"label": ""
|
||||
},
|
||||
"transformer": {
|
||||
"name": "transformer",
|
||||
"label": ""
|
||||
},
|
||||
"positive_conditioning": {
|
||||
"name": "positive_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"name": "negative_conditioning",
|
||||
"label": ""
|
||||
},
|
||||
"cfg_scale": {
|
||||
"name": "cfg_scale",
|
||||
"label": "",
|
||||
"value": 3.5
|
||||
},
|
||||
"width": {
|
||||
"name": "width",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"height": {
|
||||
"name": "height",
|
||||
"label": "",
|
||||
"value": 1024
|
||||
},
|
||||
"steps": {
|
||||
"name": "steps",
|
||||
"label": "",
|
||||
"value": 30
|
||||
},
|
||||
"seed": {
|
||||
"name": "seed",
|
||||
"label": "",
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"position": {
|
||||
"x": 813.7814762740603,
|
||||
"y": -142.20529727605867
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "t5_encoder",
|
||||
"targetHandle": "t5_encoder"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "clip_g",
|
||||
"targetHandle": "clip_g"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "clip_g",
|
||||
"targetHandle": "clip_g"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"sourceHandle": "clip_l",
|
||||
"targetHandle": "clip_l"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"sourceHandle": "clip_l",
|
||||
"targetHandle": "clip_l"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
|
||||
"type": "default",
|
||||
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "transformer",
|
||||
"targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
|
||||
"type": "default",
|
||||
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
|
||||
"type": "default",
|
||||
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
|
||||
"type": "default",
|
||||
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
|
||||
"type": "default",
|
||||
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
SD3_5_LATENT_RGB_FACTORS = [
|
||||
[-0.05240681, 0.03251581, 0.0749016],
|
||||
[-0.0580572, 0.00759826, 0.05729818],
|
||||
[0.16144888, 0.01270368, -0.03768577],
|
||||
[0.14418615, 0.08460266, 0.15941818],
|
||||
[0.04894035, 0.0056485, -0.06686988],
|
||||
[0.05187166, 0.19222395, 0.06261094],
|
||||
[0.1539433, 0.04818359, 0.07103094],
|
||||
[-0.08601796, 0.09013458, 0.10893912],
|
||||
[-0.12398469, -0.06766567, 0.0033688],
|
||||
[-0.0439737, 0.07825329, 0.02258823],
|
||||
[0.03101129, 0.06382551, 0.07753657],
|
||||
[-0.01315361, 0.08554491, -0.08772475],
|
||||
[0.06464487, 0.05914605, 0.13262741],
|
||||
[-0.07863674, -0.02261737, -0.12761454],
|
||||
[-0.09923835, -0.08010759, -0.06264447],
|
||||
[-0.03392309, -0.0804029, -0.06078822],
|
||||
]
|
||||
|
||||
FLUX_LATENT_RGB_FACTORS = [
|
||||
[-0.0412, 0.0149, 0.0521],
|
||||
[0.0056, 0.0291, 0.0768],
|
||||
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
elif base_model == BaseModelType.StableDiffusion3:
|
||||
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
|
||||
else:
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.math import attention
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
|
||||
class CustomDoubleStreamBlockProcessor:
|
||||
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:
|
||||
|
||||
@staticmethod
|
||||
def _double_stream_block_forward(
|
||||
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
|
||||
block: DoubleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
|
||||
values.
|
||||
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
|
||||
- IP-Adapter support
|
||||
"""
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
|
||||
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
|
||||
block, img, txt, vec, pe, attn_mask=attn_mask
|
||||
)
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
for ip_adapter_extension in ip_adapter_extensions:
|
||||
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class CustomSingleStreamBlockProcessor:
|
||||
"""A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
|
||||
etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _single_stream_block_forward(
|
||||
block: SingleStreamBlock,
|
||||
x: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""This function is a direct copy of SingleStreamBlock.forward()."""
|
||||
mod, _ = block.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
|
||||
q, k = block.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
@staticmethod
|
||||
def custom_single_block_forward(
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
block_index: int,
|
||||
block: SingleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> torch.Tensor:
|
||||
"""A custom implementation of SingleStreamBlock.forward() with additional features:
|
||||
- Masking
|
||||
"""
|
||||
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
|
||||
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
|
||||
|
||||
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
@@ -18,14 +19,8 @@ def denoise(
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
# positive text conditioning
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# negative text conditioning
|
||||
neg_txt: torch.Tensor | None,
|
||||
neg_txt_ids: torch.Tensor | None,
|
||||
neg_vec: torch.Tensor | None,
|
||||
pos_regional_prompting_extension: RegionalPromptingExtension,
|
||||
neg_regional_prompting_extension: RegionalPromptingExtension | None,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
@@ -61,9 +56,9 @@ def denoise(
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
@@ -78,9 +73,9 @@ def denoise(
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
@@ -88,6 +83,7 @@ def denoise(
|
||||
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
regional_prompting_extension=pos_regional_prompting_extension,
|
||||
)
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
@@ -97,15 +93,15 @@ def denoise(
|
||||
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
|
||||
# on systems with sufficient VRAM.
|
||||
|
||||
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
|
||||
if neg_regional_prompting_extension is None:
|
||||
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
|
||||
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
@@ -113,6 +109,7 @@ def denoise(
|
||||
controlnet_double_block_residuals=None,
|
||||
controlnet_single_block_residuals=None,
|
||||
ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
regional_prompting_extension=neg_regional_prompting_extension,
|
||||
)
|
||||
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
|
||||
|
||||
|
||||
276
invokeai/backend/flux/extensions/regional_prompting_extension.py
Normal file
276
invokeai/backend/flux/extensions/regional_prompting_extension.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.mask import to_standard_float_mask
|
||||
|
||||
|
||||
class RegionalPromptingExtension:
|
||||
"""A class for managing regional prompting with FLUX.
|
||||
|
||||
This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
restricted_attn_mask: torch.Tensor | None = None,
|
||||
):
|
||||
self.regional_text_conditioning = regional_text_conditioning
|
||||
self.restricted_attn_mask = restricted_attn_mask
|
||||
|
||||
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
@classmethod
|
||||
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
|
||||
"""Create a RegionalPromptingExtension from a list of text conditionings.
|
||||
|
||||
Args:
|
||||
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
|
||||
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
|
||||
"""
|
||||
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
|
||||
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
|
||||
regional_text_conditioning, img_seq_len
|
||||
)
|
||||
return cls(
|
||||
regional_text_conditioning=regional_text_conditioning,
|
||||
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
|
||||
)
|
||||
|
||||
# Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
|
||||
#
|
||||
# @classmethod
|
||||
# def _prepare_unrestricted_attn_mask(
|
||||
# cls,
|
||||
# regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
# img_seq_len: int,
|
||||
# ) -> torch.Tensor:
|
||||
# """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
|
||||
# - img self-attention is not masked.
|
||||
# - img regions attend to both txt within their own region and to global prompts.
|
||||
# """
|
||||
# device = TorchDevice.choose_torch_device()
|
||||
|
||||
# # Infer txt_seq_len from the t5_embeddings tensor.
|
||||
# txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# # Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# # There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# # 1. txt attends to itself
|
||||
# # 2. txt attends to corresponding regional img
|
||||
# # 3. regional img attends to corresponding txt
|
||||
# # 4. regional img attends to itself
|
||||
|
||||
# # Initialize empty attention mask.
|
||||
# regional_attention_mask = torch.zeros(
|
||||
# (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
|
||||
# )
|
||||
|
||||
# for image_mask, t5_embedding_range in zip(
|
||||
# regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
|
||||
# ):
|
||||
# # 1. txt attends to itself
|
||||
# regional_attention_mask[
|
||||
# t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
# ] = 1.0
|
||||
|
||||
# # 2. txt attends to corresponding regional img
|
||||
# # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
# fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
|
||||
# regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
|
||||
|
||||
# # 3. regional img attends to corresponding txt
|
||||
# # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
# fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
|
||||
# regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
|
||||
|
||||
# # 4. regional img attends to itself
|
||||
# # Allow unrestricted img self attention.
|
||||
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
|
||||
|
||||
# # Convert attention mask to boolean.
|
||||
# regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
# return regional_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _prepare_restricted_attn_mask(
|
||||
cls,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
img_seq_len: int,
|
||||
) -> torch.Tensor | None:
|
||||
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
|
||||
- img self-attention is only allowed within regions.
|
||||
- img regions only attend to txt within their own region, not to global prompts.
|
||||
"""
|
||||
# Identify background region. I.e. the region that is not covered by any region masks.
|
||||
background_region_mask: None | torch.Tensor = None
|
||||
for image_mask in regional_text_conditioning.image_masks:
|
||||
if image_mask is not None:
|
||||
if background_region_mask is None:
|
||||
background_region_mask = torch.ones_like(image_mask)
|
||||
background_region_mask *= 1 - image_mask
|
||||
|
||||
if background_region_mask is None:
|
||||
# There are no region masks, short-circuit and return None.
|
||||
# TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would
|
||||
# is a rare use case and would make the logic here significantly more complicated.
|
||||
return None
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
# Infer txt_seq_len from the t5_embeddings tensor.
|
||||
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# 1. txt attends to itself
|
||||
# 2. txt attends to corresponding regional img
|
||||
# 3. regional img attends to corresponding txt
|
||||
# 4. regional img attends to itself
|
||||
|
||||
# Initialize empty attention mask.
|
||||
regional_attention_mask = torch.zeros(
|
||||
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
|
||||
)
|
||||
|
||||
for image_mask, t5_embedding_range in zip(
|
||||
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
|
||||
):
|
||||
# 1. txt attends to itself
|
||||
regional_attention_mask[
|
||||
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
] = 1.0
|
||||
|
||||
if image_mask is not None:
|
||||
# 2. txt attends to corresponding regional img
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
image_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. regional img attends to corresponding txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
image_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# 4. regional img attends to itself
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
|
||||
else:
|
||||
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
|
||||
# that regions focus on their local prompts. We do, however, allow attention between background regions
|
||||
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
|
||||
# embeddings, which we found experimentally to cause artifacts.
|
||||
|
||||
# 2. global txt attends to background region
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
background_region_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. background region attends to global txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
background_region_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# Allow background regions to attend to themselves.
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
|
||||
|
||||
# Convert attention mask to boolean.
|
||||
regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
return regional_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _concat_regional_text_conditioning(
|
||||
cls,
|
||||
text_conditionings: list[FluxTextConditioning],
|
||||
) -> FluxRegionalTextConditioning:
|
||||
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
|
||||
concat_t5_embeddings: list[torch.Tensor] = []
|
||||
concat_t5_embedding_ranges: list[Range] = []
|
||||
image_masks: list[torch.Tensor | None] = []
|
||||
|
||||
# Choose global CLIP embedding.
|
||||
# Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
|
||||
# the first prompt's CLIP embedding.
|
||||
global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
|
||||
for text_conditioning in text_conditionings:
|
||||
if text_conditioning.mask is None:
|
||||
global_clip_embedding = text_conditioning.clip_embeddings
|
||||
break
|
||||
|
||||
cur_t5_embedding_len = 0
|
||||
for text_conditioning in text_conditionings:
|
||||
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
|
||||
|
||||
concat_t5_embedding_ranges.append(
|
||||
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
|
||||
)
|
||||
|
||||
image_masks.append(text_conditioning.mask)
|
||||
|
||||
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
|
||||
|
||||
# Initialize the txt_ids tensor.
|
||||
pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
|
||||
t5_txt_ids = torch.zeros(
|
||||
pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
return FluxRegionalTextConditioning(
|
||||
t5_embeddings=t5_embeddings,
|
||||
clip_embeddings=global_clip_embedding,
|
||||
t5_txt_ids=t5_txt_ids,
|
||||
image_masks=image_masks,
|
||||
t5_embedding_ranges=concat_t5_embedding_ranges,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_regional_prompt_mask(
|
||||
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||
|
||||
packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
|
||||
"""
|
||||
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
|
||||
|
||||
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
|
||||
# Flatten the height and width dimensions into a single image_seq_len dimension.
|
||||
return resized_mask.flatten(start_dim=2)
|
||||
@@ -41,10 +41,12 @@ def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Te
|
||||
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
|
||||
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
|
||||
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
|
||||
clip_extra_context_tokens = state_dict["ip_adapter_proj_model.proj.weight"].shape[0] // context_dim
|
||||
|
||||
return XlabsIpAdapterParams(
|
||||
num_double_blocks=num_double_blocks,
|
||||
context_dim=context_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
clip_embeddings_dim=clip_embeddings_dim,
|
||||
clip_extra_context_tokens=clip_extra_context_tokens,
|
||||
)
|
||||
|
||||
@@ -31,13 +31,16 @@ class XlabsIpAdapterParams:
|
||||
hidden_dim: int
|
||||
|
||||
clip_embeddings_dim: int
|
||||
clip_extra_context_tokens: int
|
||||
|
||||
|
||||
class XlabsIpAdapterFlux(torch.nn.Module):
|
||||
def __init__(self, params: XlabsIpAdapterParams):
|
||||
super().__init__()
|
||||
self.image_proj = ImageProjModel(
|
||||
cross_attention_dim=params.context_dim, clip_embeddings_dim=params.clip_embeddings_dim
|
||||
cross_attention_dim=params.context_dim,
|
||||
clip_embeddings_dim=params.clip_embeddings_dim,
|
||||
clip_extra_context_tokens=params.clip_extra_context_tokens,
|
||||
)
|
||||
self.ip_adapter_double_blocks = IPAdapterDoubleBlocks(
|
||||
num_double_blocks=params.num_double_blocks, context_dim=params.context_dim, hidden_dim=params.hidden_dim
|
||||
|
||||
@@ -5,10 +5,10 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
@@ -24,12 +24,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
return out.to(dtype=pos.dtype, device=pos.device)
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
return xq_out.view(*xq.shape), xk_out.view(*xk.shape)
|
||||
|
||||
@@ -5,7 +5,11 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
|
||||
from invokeai.backend.flux.custom_block_processor import (
|
||||
CustomDoubleStreamBlockProcessor,
|
||||
CustomSingleStreamBlockProcessor,
|
||||
)
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -95,6 +99,7 @@ class Flux(nn.Module):
|
||||
controlnet_double_block_residuals: list[Tensor] | None,
|
||||
controlnet_single_block_residuals: list[Tensor] | None,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -117,7 +122,6 @@ class Flux(nn.Module):
|
||||
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
assert isinstance(block, DoubleStreamBlock)
|
||||
|
||||
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
@@ -128,6 +132,7 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
ip_adapter_extensions=ip_adapter_extensions,
|
||||
regional_prompting_extension=regional_prompting_extension,
|
||||
)
|
||||
|
||||
if controlnet_double_block_residuals is not None:
|
||||
@@ -140,7 +145,17 @@ class Flux(nn.Module):
|
||||
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
|
||||
|
||||
for block_index, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
assert isinstance(block, SingleStreamBlock)
|
||||
img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
block_index=block_index,
|
||||
block=block,
|
||||
img=img,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
regional_prompting_extension=regional_prompting_extension,
|
||||
)
|
||||
|
||||
if controlnet_single_block_residuals is not None:
|
||||
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
|
||||
|
||||
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
|
||||
36
invokeai/backend/flux/text_conditioning.py
Normal file
36
invokeai/backend/flux/text_conditioning.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxTextConditioning:
|
||||
t5_embeddings: torch.Tensor
|
||||
clip_embeddings: torch.Tensor
|
||||
# If mask is None, the prompt is a global prompt.
|
||||
mask: torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxRegionalTextConditioning:
|
||||
# Concatenated text embeddings.
|
||||
# Shape: (1, concatenated_txt_seq_len, 4096)
|
||||
t5_embeddings: torch.Tensor
|
||||
# Shape: (1, concatenated_txt_seq_len, 3)
|
||||
t5_txt_ids: torch.Tensor
|
||||
|
||||
# Global CLIP embeddings.
|
||||
# Shape: (1, 768)
|
||||
clip_embeddings: torch.Tensor
|
||||
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
|
||||
# global prompt.
|
||||
# image_masks[i] is the mask for the ith prompt.
|
||||
# image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
|
||||
image_masks: list[torch.Tensor | None]
|
||||
|
||||
# List of ranges that represent the embedding ranges for each mask.
|
||||
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
|
||||
t5_embedding_ranges: list[Range]
|
||||
BIN
invokeai/backend/image_util/assets/CIELab_to_UPLab.icc
Normal file
BIN
invokeai/backend/image_util/assets/CIELab_to_UPLab.icc
Normal file
Binary file not shown.
1020
invokeai/backend/image_util/composition.py
Normal file
1020
invokeai/backend/image_util/composition.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -45,8 +45,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
# Constants for FLUX.1
|
||||
num_double_layers = 19
|
||||
num_single_layers = 38
|
||||
# inner_dim = 3072
|
||||
# mlp_ratio = 4.0
|
||||
hidden_size = 3072
|
||||
mlp_ratio = 4.0
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
|
||||
@@ -62,30 +63,43 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
|
||||
assert len(src_layer_dict) == 0
|
||||
|
||||
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
|
||||
def add_qkv_lora_layer_if_present(
|
||||
src_keys: list[str],
|
||||
src_weight_shapes: list[tuple[int, int]],
|
||||
dst_qkv_key: str,
|
||||
allow_missing_keys: bool = False,
|
||||
) -> None:
|
||||
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
|
||||
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
|
||||
"""
|
||||
# We expect that either all src keys are present or none of them are. Verify this.
|
||||
keys_present = [key in grouped_state_dict for key in src_keys]
|
||||
assert all(keys_present) or not any(keys_present)
|
||||
|
||||
# If none of the keys are present, return early.
|
||||
keys_present = [key in grouped_state_dict for key in src_keys]
|
||||
if not any(keys_present):
|
||||
return
|
||||
|
||||
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
|
||||
sub_layers: list[LoRALayer] = []
|
||||
for src_layer_dict in src_layer_dicts:
|
||||
values = {
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
}
|
||||
if alpha is not None:
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
assert len(src_layer_dict) == 0
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
|
||||
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
|
||||
src_layer_dict = grouped_state_dict.pop(src_key, None)
|
||||
if src_layer_dict is not None:
|
||||
values = {
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
}
|
||||
if alpha is not None:
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
|
||||
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
assert len(src_layer_dict) == 0
|
||||
else:
|
||||
if not allow_missing_keys:
|
||||
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
|
||||
values = {
|
||||
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
|
||||
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
|
||||
}
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
|
||||
|
||||
# time_text_embed.timestep_embedder -> time_in.
|
||||
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
|
||||
@@ -118,6 +132,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
f"transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer_blocks.{i}.attn.to_v",
|
||||
],
|
||||
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
|
||||
f"double_blocks.{i}.img_attn.qkv",
|
||||
)
|
||||
add_qkv_lora_layer_if_present(
|
||||
@@ -126,6 +141,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
f"transformer_blocks.{i}.attn.add_k_proj",
|
||||
f"transformer_blocks.{i}.attn.add_v_proj",
|
||||
],
|
||||
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
|
||||
f"double_blocks.{i}.txt_attn.qkv",
|
||||
)
|
||||
|
||||
@@ -175,7 +191,14 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
f"single_transformer_blocks.{i}.attn.to_v",
|
||||
f"single_transformer_blocks.{i}.proj_mlp",
|
||||
],
|
||||
[
|
||||
(hidden_size, hidden_size),
|
||||
(hidden_size, hidden_size),
|
||||
(hidden_size, hidden_size),
|
||||
(mlp_hidden_dim, hidden_size),
|
||||
],
|
||||
f"single_blocks.{i}.linear1",
|
||||
allow_missing_keys=True,
|
||||
)
|
||||
|
||||
# Output projections.
|
||||
|
||||
@@ -53,8 +53,7 @@ class BaseModelType(str, Enum):
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
# TODO(ryand): Should this just be StableDiffusion3?
|
||||
StableDiffusion35 = "sd-3.5"
|
||||
StableDiffusion3 = "sd-3"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
@@ -85,8 +84,10 @@ class SubModelType(str, Enum):
|
||||
Transformer = "transformer"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
TextEncoder3 = "text_encoder_3"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Tokenizer3 = "tokenizer_3"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
@@ -94,6 +95,13 @@ class SubModelType(str, Enum):
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ClipVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
L = "large"
|
||||
G = "gigantic"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
@@ -149,6 +157,17 @@ class ModelSourceType(str, Enum):
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
variant: AnyVariant = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class MainModelDefaultSettings(BaseModel):
|
||||
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
|
||||
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
|
||||
@@ -195,6 +214,9 @@ class ModelConfigBase(BaseModel):
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
|
||||
description="Loadable submodels in this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
@@ -337,7 +359,7 @@ class MainConfigBase(ModelConfigBase):
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
variant: AnyVariant = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
@@ -421,12 +443,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
|
||||
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
variant: ClipVariantType = ClipVariantType.L
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
||||
"""Model config for CLIP-G Embeddings."""
|
||||
|
||||
variant: ClipVariantType = ClipVariantType.G
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
|
||||
|
||||
|
||||
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
|
||||
"""Model config for CLIP-L Embeddings."""
|
||||
|
||||
variant: ClipVariantType = ClipVariantType.L
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
@@ -503,6 +546,8 @@ AnyModelConfig = Annotated[
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
|
||||
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
|
||||
@@ -5,7 +5,6 @@ Base class for model loading in InvokeAI.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
@@ -18,19 +17,17 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModelWithoutConfig:
|
||||
"""
|
||||
Context manager object that mediates transfer from RAM<->VRAM.
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM.
|
||||
|
||||
This is a context manager object that has two distinct APIs:
|
||||
|
||||
1. Older API (deprecated):
|
||||
Use the LoadedModel object directly as a context manager.
|
||||
It will move the model into VRAM (on CUDA devices), and
|
||||
Use the LoadedModel object directly as a context manager. It will move the model into VRAM (on CUDA devices), and
|
||||
return the model in a form suitable for passing to torch.
|
||||
Example:
|
||||
```
|
||||
@@ -40,13 +37,9 @@ class LoadedModelWithoutConfig:
|
||||
```
|
||||
|
||||
2. Newer API (recommended):
|
||||
Call the LoadedModel's `model_on_device()` method in a
|
||||
context. It returns a tuple consisting of a copy of
|
||||
the model's state dict in CPU RAM followed by a copy
|
||||
of the model in VRAM. The state dict is provided to allow
|
||||
LoRAs and other model patchers to return the model to
|
||||
its unpatched state without expensive copy and restore
|
||||
operations.
|
||||
Call the LoadedModel's `model_on_device()` method in a context. It returns a tuple consisting of a copy of the
|
||||
model's state dict in CPU RAM followed by a copy of the model in VRAM. The state dict is provided to allow LoRAs and
|
||||
other model patchers to return the model to its unpatched state without expensive copy and restore operations.
|
||||
|
||||
Example:
|
||||
```
|
||||
@@ -55,43 +48,42 @@ class LoadedModelWithoutConfig:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The state_dict should be treated as a read-only object and
|
||||
never modified. Also be aware that some loadable models do
|
||||
not have a state_dict, in which case this value will be None.
|
||||
The state_dict should be treated as a read-only object and never modified. Also be aware that some loadable models
|
||||
do not have a state_dict, in which case this value will be None.
|
||||
"""
|
||||
|
||||
_locker: ModelLockerBase
|
||||
def __init__(self, cache_record: CacheRecord, cache: ModelCache):
|
||||
self._cache_record = cache_record
|
||||
self._cache = cache
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Context exit."""
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@contextmanager
|
||||
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||
locked_model = self._locker.lock()
|
||||
self._cache.lock(self._cache_record.key)
|
||||
try:
|
||||
state_dict = self._locker.get_state_dict()
|
||||
yield (state_dict, locked_model)
|
||||
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
|
||||
finally:
|
||||
self._locker.unlock()
|
||||
self._cache.unlock(self._cache_record.key)
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without locking it."""
|
||||
return self._locker.model
|
||||
return self._cache_record.cached_model.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
def __init__(self, config: Optional[AnyModelConfig], cache_record: CacheRecord, cache: ModelCache):
|
||||
super().__init__(cache_record=cache_record, cache=cache)
|
||||
self.config = config
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
@@ -110,7 +102,7 @@ class ModelLoaderBase(ABC):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
pass
|
||||
@@ -138,6 +130,6 @@ class ModelLoaderBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
pass
|
||||
|
||||
@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -28,13 +29,14 @@ class ModelLoader(ModelLoaderBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._torch_dtype = TorchDevice.choose_torch_dtype()
|
||||
self._torch_device = TorchDevice.choose_torch_device()
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
@@ -53,11 +55,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
|
||||
|
||||
with skip_torch_weight_init():
|
||||
locker = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, _locker=locker)
|
||||
cache_record = self._load_and_cache(model_config, submodel_type)
|
||||
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
def ram_cache(self) -> ModelCache:
|
||||
"""Return the ram cache associated with this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@@ -65,10 +67,10 @@ class ModelLoader(ModelLoaderBase):
|
||||
model_base = self._app_config.models_path
|
||||
return (model_base / config.path).resolve()
|
||||
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
|
||||
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord:
|
||||
stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")])
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type, stats_name=stats_name)
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@@ -77,16 +79,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
loaded_model = self._load_model(config, submodel_type)
|
||||
|
||||
self._ram_cache.put(
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
get_model_cache_key(config.key, submodel_type),
|
||||
model=loaded_model,
|
||||
)
|
||||
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=stats_name,
|
||||
)
|
||||
return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name)
|
||||
|
||||
def get_size_fs(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Init file for ModelCache."""
|
||||
|
||||
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
|
||||
from .model_cache_default import ModelCache # noqa F401
|
||||
|
||||
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord:
|
||||
"""A class that represents a model in the model cache."""
|
||||
|
||||
# Cache key.
|
||||
key: str
|
||||
# Model in memory.
|
||||
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
|
||||
# If locks > 0, the model is actively being used, so we should do our best to keep it on the compute device.
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def is_locked(self) -> bool:
|
||||
return self._locks > 0
|
||||
@@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
@@ -0,0 +1,81 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class CachedModelOnlyFullLoad:
|
||||
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
|
||||
"""Initialize a CachedModelOnlyFullLoad.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
|
||||
compute_device (torch.device): The compute device to move the model to.
|
||||
total_bytes (int): The total size (in bytes) of all the weights in the model.
|
||||
"""
|
||||
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
self._total_bytes = total_bytes
|
||||
self._is_in_vram = False
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better and implement it.
|
||||
return None
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._is_in_vram:
|
||||
return self._total_bytes
|
||||
else:
|
||||
return 0
|
||||
|
||||
def is_in_vram(self) -> bool:
|
||||
"""Return true if the model is currently in VRAM."""
|
||||
return self._is_in_vram
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM (if supported by the model).
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
if self._is_in_vram:
|
||||
# Already in VRAM.
|
||||
return 0
|
||||
|
||||
if not hasattr(self._model, "to"):
|
||||
# Model doesn't support moving to a device.
|
||||
return 0
|
||||
|
||||
self._model.to(self._compute_device)
|
||||
self._is_in_vram = True
|
||||
return self._total_bytes
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
if not self._is_in_vram:
|
||||
# Already in RAM.
|
||||
return 0
|
||||
|
||||
self._model.to("cpu")
|
||||
self._is_in_vram = False
|
||||
return self._total_bytes
|
||||
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_function_autocast_context import (
|
||||
add_autocast_to_module_forward,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
def set_nested_attr(obj: object, attr: str, value: object):
|
||||
"""A helper function that extends setattr() to support nested attributes.
|
||||
|
||||
Example:
|
||||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for attr in attrs[:-1]:
|
||||
obj = getattr(obj, attr)
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# Monkey-patch the model to add autocasting to the model's forward method.
|
||||
add_autocast_to_module_forward(model, compute_device)
|
||||
|
||||
# TODO(ryand): Manage a read-only CPU copy of the model state dict.
|
||||
# TODO(ryand): Add memoization for total_bytes and cur_vram_bytes?
|
||||
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._model.parameters())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better.
|
||||
return self._cpu_state_dict
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._cur_vram_bytes is None:
|
||||
self._cur_vram_bytes = sum(
|
||||
calc_tensor_size(p) for p in self._model.parameters() if p.device.type == self._compute_device.type
|
||||
)
|
||||
return self._cur_vram_bytes
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM."""
|
||||
return self.partial_load_to_vram(self.total_bytes())
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM."""
|
||||
return self.partial_unload_from_vram(self.total_bytes())
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
|
||||
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
vram_bytes_loaded = 0
|
||||
|
||||
# TODO(ryand): Iterate over buffers too?
|
||||
for key, param in self._model.named_parameters():
|
||||
# Skip parameters that are already on the compute device.
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
# Check the size of the parameter.
|
||||
param_size = calc_tensor_size(param)
|
||||
if vram_bytes_loaded + param_size > vram_bytes_to_load:
|
||||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
|
||||
# worth continuing to search for a smaller parameter that would fit?
|
||||
continue
|
||||
|
||||
# Copy the parameter to the compute device.
|
||||
# We use the 'overwrite' strategy from torch.nn.Module._apply().
|
||||
# TODO(ryand): For some edge cases (e.g. quantized models?), we may need to support other strategies (e.g.
|
||||
# swap).
|
||||
assert isinstance(param, torch.nn.Parameter)
|
||||
assert param.is_leaf
|
||||
out_param = torch.nn.Parameter(param.to(self._compute_device, copy=True), requires_grad=param.requires_grad)
|
||||
set_nested_attr(self._model, key, out_param)
|
||||
# We did not port the param.grad handling from torch.nn.Module._apply(), because we do not expect to be
|
||||
# handling gradients. We assert that this assumption is true.
|
||||
assert param.grad is None
|
||||
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
|
||||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
vram_bytes_freed = 0
|
||||
|
||||
# TODO(ryand): Iterate over buffers too?
|
||||
for key, param in self._model.named_parameters():
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
|
||||
if param.device.type != self._compute_device.type:
|
||||
continue
|
||||
|
||||
# Create a new parameter, but inject the existing CPU tensor into it.
|
||||
out_param = torch.nn.Parameter(self._cpu_state_dict[key], requires_grad=param.requires_grad)
|
||||
set_nested_attr(self._model, key, out_param)
|
||||
vram_bytes_freed += calc_tensor_size(param)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
return vram_bytes_freed
|
||||
534
invokeai/backend/model_manager/load/model_cache/model_cache.py
Normal file
534
invokeai/backend/model_manager/load/model_cache/model_cache.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import gc
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
|
||||
CachedModelOnlyFullLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
# TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels.
|
||||
def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
"""Get the cache key for a model based on the optional submodel type."""
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
|
||||
class ModelCache:
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
# TODO(ryand): Think about what lazy_offloading should mean in the new model cache.
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = PrefixedLoggerAdapter(
|
||||
logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE"
|
||||
)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collecting cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def put(self, key: str, model: AnyModel) -> None:
|
||||
"""Add a model to the cache."""
|
||||
if key in self._cached_models:
|
||||
self._logger.debug(
|
||||
f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary."
|
||||
)
|
||||
return
|
||||
|
||||
size = calc_model_size_by_data(self._logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
# Wrap model.
|
||||
if isinstance(model, torch.nn.Module):
|
||||
wrapped_model = CachedModelWithPartialLoad(model, self._execution_device)
|
||||
else:
|
||||
wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size)
|
||||
|
||||
# running_on_cpu = self._execution_device == torch.device("cpu")
|
||||
# state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, cached_model=wrapped_model)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
self._logger.debug(
|
||||
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)"
|
||||
)
|
||||
|
||||
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
||||
"""Retrieve a model from the cache.
|
||||
|
||||
:param key: Model key
|
||||
:param stats_name: A human-readable id for the model for the purposes of stats reporting.
|
||||
|
||||
Raises IndexError if the model is not in the cache.
|
||||
"""
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
self._logger.debug(f"Cache miss: {key}")
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes()
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
self._cache_stack = [k for k in self._cache_stack if k != key]
|
||||
self._cache_stack.append(key)
|
||||
|
||||
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
return cache_entry
|
||||
|
||||
def lock(self, key: str) -> None:
|
||||
"""Lock a model for use and move it into VRAM."""
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry.lock()
|
||||
|
||||
self._logger.debug(f"Locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
try:
|
||||
self._load_locked_model(cache_entry)
|
||||
self._logger.debug(
|
||||
f"Finished locking model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})"
|
||||
)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
cache_entry.unlock()
|
||||
raise
|
||||
|
||||
self._log_cache_state()
|
||||
|
||||
def unlock(self, key: str) -> None:
|
||||
"""Unlock a model."""
|
||||
cache_entry = self._cached_models[key]
|
||||
cache_entry.unlock()
|
||||
self._logger.debug(f"Unlocked model {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
||||
|
||||
def _load_locked_model(self, cache_entry: CacheRecord) -> None:
|
||||
"""Helper function for self.lock(). Loads a locked model into VRAM."""
|
||||
vram_available = self._get_vram_available()
|
||||
|
||||
# The amount of additional VRAM that will be used if we fully load the model into VRAM.
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
model_total_bytes = cache_entry.cached_model.total_bytes()
|
||||
model_vram_needed = model_total_bytes - model_cur_vram_bytes
|
||||
|
||||
self._logger.debug(
|
||||
f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Make room for the model in VRAM.
|
||||
# 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully.
|
||||
# 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as
|
||||
# possible.
|
||||
vram_bytes_freed = self._offload_unlocked_models(model_vram_needed)
|
||||
self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB")
|
||||
|
||||
# Check the updated vram_available after offloading.
|
||||
vram_available = self._get_vram_available()
|
||||
self._logger.debug(
|
||||
f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
# Move as much of the model as possible into VRAM.
|
||||
model_bytes_loaded = 0
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
model_bytes_loaded = cache_entry.cached_model.partial_load_to_vram(vram_available)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
# Partial load is not supported, so we have not choice but to try and fit it all into VRAM.
|
||||
model_bytes_loaded = cache_entry.cached_model.full_load_to_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
|
||||
model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
|
||||
vram_available = self._get_vram_available()
|
||||
self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ")
|
||||
self._logger.debug(
|
||||
f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}"
|
||||
)
|
||||
|
||||
def _get_vram_available(self) -> int:
|
||||
"""Get the amount of VRAM available in the cache."""
|
||||
return int(self._max_vram_cache_size * GB) - self._get_vram_in_use()
|
||||
|
||||
def _get_vram_in_use(self) -> int:
|
||||
"""Get the amount of VRAM currently in use."""
|
||||
return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _get_ram_available(self) -> int:
|
||||
"""Get the amount of RAM available in the cache."""
|
||||
return int(self._max_cache_size * GB) - self._get_ram_in_use()
|
||||
|
||||
def _get_ram_in_use(self) -> int:
|
||||
"""Get the amount of RAM currently in use."""
|
||||
return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values())
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str:
|
||||
"""Helper function for preparing a VRAM state log string."""
|
||||
model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0
|
||||
return (
|
||||
f"model_total={model_total_bytes/MB:.0f} MB, "
|
||||
+ f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), "
|
||||
+ f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, "
|
||||
+ f"vram_available={(vram_available/MB):.0f} MB, "
|
||||
)
|
||||
|
||||
def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int:
|
||||
"""Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are
|
||||
offloaded. Of course, locked models are not offloaded.
|
||||
|
||||
Returns:
|
||||
int: The number of bytes freed.
|
||||
"""
|
||||
self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.")
|
||||
vram_bytes_freed = 0
|
||||
# TODO(ryand): Give more thought to the offloading policy used here.
|
||||
cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes())
|
||||
for cache_entry in cache_entries_increasing_size:
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
if cache_entry.is_locked:
|
||||
continue
|
||||
|
||||
if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad):
|
||||
cache_entry_bytes_freed = cache_entry.cached_model.partial_unload_from_vram(
|
||||
vram_bytes_to_free - vram_bytes_freed
|
||||
)
|
||||
elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore
|
||||
cache_entry_bytes_freed = cache_entry.cached_model.full_unload_from_vram()
|
||||
else:
|
||||
raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}")
|
||||
if cache_entry_bytes_freed > 0:
|
||||
self._logger.debug(
|
||||
f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."
|
||||
)
|
||||
vram_bytes_freed += cache_entry_bytes_freed
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
return vram_bytes_freed
|
||||
|
||||
# def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
|
||||
# """Move model into the indicated device.
|
||||
|
||||
# :param cache_entry: The CacheRecord for the model
|
||||
# :param target_device: The torch.device to move the model into
|
||||
|
||||
# May raise a torch.cuda.OutOfMemoryError
|
||||
# """
|
||||
# self._logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
# source_device = cache_entry.device
|
||||
|
||||
# # Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# # This would need to be revised to support multi-GPU.
|
||||
# if torch.device(source_device).type == torch.device(target_device).type:
|
||||
# return
|
||||
|
||||
# # Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
# if not hasattr(cache_entry.model, "to"):
|
||||
# return
|
||||
|
||||
# # This roundabout method for moving the model around is done to avoid
|
||||
# # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# # When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# # RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# # This operation is slightly faster than running `to()` on the whole model.
|
||||
# #
|
||||
# # When the model needs to be removed from VRAM we simply delete the copy
|
||||
# # of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# # in RAM into the model. So this operation is very fast.
|
||||
# start_model_to_time = time.time()
|
||||
# snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
# try:
|
||||
# if cache_entry.state_dict is not None:
|
||||
# assert hasattr(cache_entry.model, "load_state_dict")
|
||||
# if target_device == self._storage_device:
|
||||
# cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
# else:
|
||||
# new_dict: Dict[str, torch.Tensor] = {}
|
||||
# for k, v in cache_entry.state_dict.items():
|
||||
# new_dict[k] = v.to(target_device, copy=True)
|
||||
# cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
# cache_entry.model.to(target_device)
|
||||
# cache_entry.device = target_device
|
||||
# except Exception as e: # blow away cache entry
|
||||
# self._delete_cache_entry(cache_entry)
|
||||
# raise e
|
||||
|
||||
# snapshot_after = self._capture_memory_snapshot()
|
||||
# end_model_to_time = time.time()
|
||||
# self._logger.debug(
|
||||
# f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
# f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
# f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# snapshot_before is not None
|
||||
# and snapshot_after is not None
|
||||
# and snapshot_before.vram is not None
|
||||
# and snapshot_after.vram is not None
|
||||
# ):
|
||||
# vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# # If the estimated model size does not match the change in VRAM, log a warning.
|
||||
# if not math.isclose(
|
||||
# vram_change,
|
||||
# cache_entry.size,
|
||||
# rel_tol=0.1,
|
||||
# abs_tol=10 * MB,
|
||||
# ):
|
||||
# self._logger.debug(
|
||||
# f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
# f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
# " estimated size may be incorrect. Estimated model size:"
|
||||
# f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
# f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
# )
|
||||
|
||||
def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True):
|
||||
ram_size_bytes = self._max_cache_size * GB
|
||||
ram_in_use_bytes = self._get_ram_in_use()
|
||||
ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
ram_available_bytes = self._get_ram_available()
|
||||
ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0
|
||||
|
||||
vram_size_bytes = self._max_vram_cache_size * GB
|
||||
vram_in_use_bytes = self._get_vram_in_use()
|
||||
vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
vram_available_bytes = self._get_vram_available()
|
||||
vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0
|
||||
|
||||
log = f"{title}\n"
|
||||
|
||||
log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n"
|
||||
log += log_format.format(
|
||||
f"Storage Device ({self._storage_device.type})",
|
||||
ram_size_bytes / MB,
|
||||
ram_in_use_bytes / MB,
|
||||
ram_in_use_bytes_percent,
|
||||
ram_available_bytes / MB,
|
||||
ram_available_bytes_percent,
|
||||
)
|
||||
log += log_format.format(
|
||||
f"Compute Device ({self._execution_device.type})",
|
||||
vram_size_bytes / MB,
|
||||
vram_in_use_bytes / MB,
|
||||
vram_in_use_bytes_percent,
|
||||
vram_available_bytes / MB,
|
||||
vram_available_bytes_percent,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
log += " {:<30} {} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB)
|
||||
log += " {:<30} {}\n".format("Total models:", len(self._cached_models))
|
||||
|
||||
if include_entry_details and len(self._cached_models) > 0:
|
||||
log += " Models:\n"
|
||||
log_format = (
|
||||
" {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n"
|
||||
)
|
||||
for cache_record in self._cached_models.values():
|
||||
total_bytes = cache_record.cached_model.total_bytes()
|
||||
cur_vram_bytes = cache_record.cached_model.cur_vram_bytes()
|
||||
cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
cur_ram_bytes = total_bytes - cur_vram_bytes
|
||||
cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0
|
||||
|
||||
log += log_format.format(
|
||||
f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):",
|
||||
total_bytes / MB,
|
||||
cur_vram_bytes / MB,
|
||||
cur_vram_bytes_percent,
|
||||
cur_ram_bytes / MB,
|
||||
cur_ram_bytes_percent,
|
||||
cache_record.is_locked,
|
||||
)
|
||||
|
||||
self._logger.debug(log)
|
||||
|
||||
def make_room(self, bytes_needed: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="Before dropping models:")
|
||||
|
||||
ram_bytes_available = self._get_ram_available()
|
||||
ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available)
|
||||
|
||||
ram_bytes_freed = 0
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
if not cache_entry.is_locked:
|
||||
ram_bytes_freed += cache_entry.cached_model.total_bytes()
|
||||
self._logger.debug(
|
||||
f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB."
|
||||
)
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
models_cleared += 1
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.")
|
||||
self._log_cache_state(title="After dropping models:")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
@@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLockerBase(ABC):
|
||||
"""Base class for the model locker used by the loader."""
|
||||
|
||||
@abstractmethod
|
||||
def lock(self) -> AnyModel:
|
||||
"""Lock the contained model and move it into VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unlock(self) -> None:
|
||||
"""Unlock the contained model, and remove it from VRAM."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model."""
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
|
||||
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||
context manager call `model_on_device()`.
|
||||
"""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
def lock(self) -> None:
|
||||
"""Lock this record."""
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Unlock this record."""
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self) -> bool:
|
||||
"""Return true if record is locked."""
|
||||
return self._locks > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelCacheBase(ABC, Generic[T]):
|
||||
"""Virtual base class for RAM model cache."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the maximum size the RAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
@abstractmethod
|
||||
def max_vram_cache_size(self, value: float) -> float:
|
||||
"""Set the maximum size the VRAM cache can grow to."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload from VRAM any models not actively in use."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
pass
|
||||
|
||||
@stats.setter
|
||||
@abstractmethod
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: T,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log debugging information on CUDA usage."""
|
||||
pass
|
||||
@@ -1,426 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
# TODO: Add Stalker's proper name to copyright
|
||||
""" """
|
||||
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
|
||||
CacheRecord,
|
||||
CacheStats,
|
||||
ModelCacheBase,
|
||||
ModelLockerBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLocker
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# Size of a GB in bytes.
|
||||
GB = 2**30
|
||||
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""A cache for managing models in memory.
|
||||
|
||||
The cache is based on two levels of model storage:
|
||||
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
|
||||
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
|
||||
|
||||
The model cache is based on the following assumptions:
|
||||
- storage_device_mem_size > execution_device_mem_size
|
||||
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
|
||||
|
||||
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
|
||||
the execution_device.
|
||||
|
||||
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
|
||||
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
|
||||
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
|
||||
|
||||
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
|
||||
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
|
||||
configuration.
|
||||
|
||||
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
|
||||
the context, and unload outside the context.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
|
||||
do_something_on_gpu(SD1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float,
|
||||
max_vram_cache_size: float,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
lazy_offloading: bool = True,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the model RAM cache.
|
||||
|
||||
:param max_cache_size: Maximum size of the storage_device cache in GBs.
|
||||
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
:param logger: InvokeAILogger to use (otherwise creates one)
|
||||
"""
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self._max_cache_size: float = max_cache_size
|
||||
self._max_vram_cache_size: float = max_vram_cache_size
|
||||
self._execution_device: torch.device = execution_device
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage
|
||||
self._stats: Optional[CacheStats] = None
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
"""Return the logger used by the cache."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def lazy_offloading(self) -> bool:
|
||||
"""Return true if the cache is configured to lazily offload models in VRAM."""
|
||||
return self._lazy_offloading
|
||||
|
||||
@property
|
||||
def storage_device(self) -> torch.device:
|
||||
"""Return the storage device (e.g. "CPU" for RAM)."""
|
||||
return self._storage_device
|
||||
|
||||
@property
|
||||
def execution_device(self) -> torch.device:
|
||||
"""Return the exection device (e.g. "cuda" for VRAM)."""
|
||||
return self._execution_device
|
||||
|
||||
@property
|
||||
def max_cache_size(self) -> float:
|
||||
"""Return the cap on cache size."""
|
||||
return self._max_cache_size
|
||||
|
||||
@max_cache_size.setter
|
||||
def max_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on cache size."""
|
||||
self._max_cache_size = value
|
||||
|
||||
@property
|
||||
def max_vram_cache_size(self) -> float:
|
||||
"""Return the cap on vram cache size."""
|
||||
return self._max_vram_cache_size
|
||||
|
||||
@max_vram_cache_size.setter
|
||||
def max_vram_cache_size(self, value: float) -> None:
|
||||
"""Set the cap on vram cache size."""
|
||||
self._max_vram_cache_size = value
|
||||
|
||||
@property
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
"""Return collected CacheStats object."""
|
||||
return self._stats
|
||||
|
||||
@stats.setter
|
||||
def stats(self, stats: CacheStats) -> None:
|
||||
"""Set the CacheStats object for collectin cache statistics."""
|
||||
self._stats = stats
|
||||
|
||||
def cache_size(self) -> int:
|
||||
"""Get the total size of the models currently cached."""
|
||||
total = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
total += cache_record.size
|
||||
return total
|
||||
|
||||
def put(
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(self.logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
running_on_cpu = self.execution_device == torch.device("cpu")
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
if self.stats:
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GB)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
)
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
|
||||
if submodel_type:
|
||||
return f"{model_key}:{submodel_type.value}"
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Offload models from the execution_device to make room for size_required.
|
||||
|
||||
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
|
||||
"""
|
||||
reserved = self._max_vram_cache_size * GB
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
|
||||
:param cache_entry: The CacheRecord for the model
|
||||
:param target_device: The torch.device to move the model into
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(target_device, copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
|
||||
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{cache_entry.key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GB):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def print_cuda_stats(self) -> None:
|
||||
"""Log CUDA diagnostics."""
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
|
||||
ram = "%4.2fG" % (self.cache_size() / GB)
|
||||
|
||||
in_ram_models = 0
|
||||
in_vram_models = 0
|
||||
locked_in_vram_models = 0
|
||||
for cache_record in self._cached_models.values():
|
||||
if hasattr(cache_record.model, "device"):
|
||||
if cache_record.model.device == self.storage_device:
|
||||
in_ram_models += 1
|
||||
else:
|
||||
in_vram_models += 1
|
||||
if cache_record.locked:
|
||||
locked_in_vram_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size.
|
||||
|
||||
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
|
||||
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
||||
garbage-collected.
|
||||
"""
|
||||
bytes_needed = size
|
||||
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GB):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
if self.stats:
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Base class and implementation of a class that moves models in and out of VRAM.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import (
|
||||
CacheRecord,
|
||||
ModelCacheBase,
|
||||
ModelLockerBase,
|
||||
)
|
||||
|
||||
|
||||
class ModelLocker(ModelLockerBase):
|
||||
"""Internal class that mediates movement in and out of GPU."""
|
||||
|
||||
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
|
||||
"""
|
||||
Initialize the model locker.
|
||||
|
||||
:param cache: The ModelCache object
|
||||
:param cache_entry: The entry in the model cache
|
||||
"""
|
||||
self._cache = cache
|
||||
self._cache_entry = cache_entry
|
||||
|
||||
@property
|
||||
def model(self) -> AnyModel:
|
||||
"""Return the model without moving it around."""
|
||||
return self._cache_entry.model
|
||||
|
||||
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||
"""Return the state dict (if any) for the cached model."""
|
||||
return self._cache_entry.state_dict
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||
self._cache_entry.loaded = True
|
||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||
self._cache.print_cuda_stats()
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.print_cuda_stats()
|
||||
@@ -0,0 +1,33 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
|
||||
def add_autocast_to_module_forward(m: torch.nn.Module, to_device: torch.device):
|
||||
"""Monkey-patch m.forward(...) with a new forward(...) method that activates device autocasting for its duration."""
|
||||
old_forward = m.forward
|
||||
|
||||
def new_forward(*args: Any, **kwargs: Any):
|
||||
with TorchFunctionAutocastDeviceContext(to_device):
|
||||
return old_forward(*args, **kwargs)
|
||||
|
||||
m.forward = new_forward
|
||||
|
||||
|
||||
def _cast_to_device_and_run(
|
||||
func: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any], to_device: torch.device
|
||||
):
|
||||
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
|
||||
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
return func(*args_on_device, **kwargs_on_device)
|
||||
|
||||
|
||||
class TorchFunctionAutocastDeviceContext(TorchFunctionMode):
|
||||
def __init__(self, to_device: torch.device):
|
||||
self._to_device = to_device
|
||||
|
||||
def __torch_function__(
|
||||
self, func: Callable[..., Any], types, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None
|
||||
):
|
||||
return _cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
|
||||
@@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
|
||||
model = AutoEncoder(ae_params[config.config_path])
|
||||
sd = load_file(model_path)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
model.to(dtype=self._torch_dtype)
|
||||
# VAE is broken in float16, which mps defaults to
|
||||
if self._torch_dtype == torch.float16:
|
||||
try:
|
||||
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
|
||||
except TypeError:
|
||||
vae_dtype = torch.float32
|
||||
else:
|
||||
vae_dtype = self._torch_dtype
|
||||
model.to(vae_dtype)
|
||||
|
||||
return model
|
||||
|
||||
@@ -128,9 +136,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
)
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
te2_model_path = Path(config.path) / "text_encoder_2"
|
||||
model_config = AutoConfig.from_pretrained(te2_model_path)
|
||||
with accelerate.init_empty_weights():
|
||||
@@ -172,9 +180,9 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
raise ValueError("Only T5EncoderConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
|
||||
|
||||
raise ValueError(
|
||||
|
||||
@@ -26,7 +26,7 @@ from invokeai.backend.model_manager import (
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class LoRALoader(ModelLoader):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
ram_cache: ModelCache,
|
||||
):
|
||||
"""Initialize the loader."""
|
||||
super().__init__(app_config, logger, ram_cache)
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion35, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class FluxCheckpointModel(ModelLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Transformer:
|
||||
return self._load_from_singlefile(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
||||
)
|
||||
|
||||
def _load_from_singlefile(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
model_path = Path(config.path)
|
||||
|
||||
# model = Flux(params[config.config_path])
|
||||
# sd = load_file(model_path)
|
||||
# if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
|
||||
# sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
# new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
|
||||
# self._ram_cache.make_room(new_sd_size)
|
||||
# for k in sd.keys():
|
||||
# # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
|
||||
# sd[k] = sd[k].to(torch.bfloat16)
|
||||
# model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
@@ -25,6 +25,7 @@ from invokeai.backend.model_manager.config import (
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
@@ -42,6 +43,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@@ -51,13 +53,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
@@ -117,8 +112,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
load_class = load_classes[config.base][config.variant]
|
||||
except KeyError as e:
|
||||
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
|
||||
prediction_type = config.prediction_type.value
|
||||
upcast_attention = config.upcast_attention
|
||||
|
||||
# Without SilenceWarnings we get log messages like this:
|
||||
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
|
||||
@@ -129,13 +122,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
# ['text_model.embeddings.position_ids']
|
||||
|
||||
with SilenceWarnings():
|
||||
pipeline = load_class.from_single_file(
|
||||
config.path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
prediction_type=prediction_type,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=False,
|
||||
)
|
||||
pipeline = load_class.from_single_file(config.path, torch_dtype=self._torch_dtype)
|
||||
|
||||
if not submodel_type:
|
||||
return pipeline
|
||||
@@ -146,5 +133,5 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel)
|
||||
return getattr(pipeline, submodel_type.value)
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Optional
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
|
||||
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ControlAdapterDefaultSettings,
|
||||
InvalidModelConfigException,
|
||||
@@ -33,11 +34,17 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubmodelDefinition,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
get_clip_variant_type,
|
||||
lora_token_vector_length,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.sd3.sd3_state_dict_utils import is_sd3_checkpoint
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
@@ -113,6 +120,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusion3Pipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
@@ -121,11 +129,14 @@ class ModelProbe(object):
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
"CLIPModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
"FluxControlNetModel": ModelType.ControlNet,
|
||||
"SD3Transformer2DModel": ModelType.Main,
|
||||
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
|
||||
}
|
||||
|
||||
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
|
||||
@@ -172,7 +183,10 @@ class ModelProbe(object):
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||
variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
|
||||
fields["variant"] = (
|
||||
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
|
||||
)
|
||||
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||
@@ -219,6 +233,10 @@ class ModelProbe(object):
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
get_submodels = getattr(probe, "get_submodels", None)
|
||||
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
|
||||
fields["submodels"] = get_submodels()
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
|
||||
return model_info
|
||||
|
||||
@@ -243,11 +261,6 @@ class ModelProbe(object):
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(
|
||||
(
|
||||
# The following prefixes appear when multiple models have been bundled together in a single file (I
|
||||
# believe the format originated in ComfyUI).
|
||||
# first_stage_model = VAE
|
||||
# cond_stage_model = Text Encoder
|
||||
# model.diffusion_model = UNet / Transformer
|
||||
"cond_stage_model.",
|
||||
"first_stage_model.",
|
||||
"model.diffusion_model.",
|
||||
@@ -404,9 +417,6 @@ class ModelProbe(object):
|
||||
# is used rather than attempting to support flux with separate model types and format
|
||||
# If changed in the future, please fix me
|
||||
config_file = "flux-schnell"
|
||||
elif base_type == BaseModelType.StableDiffusion35:
|
||||
# TODO(ryand): Think about what to do here.
|
||||
config_file = "sd3.5-large"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
@@ -459,7 +469,7 @@ class ModelProbe(object):
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
@@ -472,8 +482,10 @@ MODEL_NAME_TO_PREPROCESSOR = {
|
||||
"normal": "normalbae_image_processor",
|
||||
"sketch": "pidi_image_processor",
|
||||
"scribble": "lineart_image_processor",
|
||||
"lineart": "lineart_image_processor",
|
||||
"lineart anime": "lineart_anime_image_processor",
|
||||
"lineart_anime": "lineart_anime_image_processor",
|
||||
"lineart": "lineart_image_processor",
|
||||
"soft": "hed_image_processor",
|
||||
"softedge": "hed_image_processor",
|
||||
"hed": "hed_image_processor",
|
||||
"shuffle": "content_shuffle_image_processor",
|
||||
@@ -526,7 +538,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
base_type = self.get_base_type()
|
||||
if model_type != ModelType.Main or base_type in (BaseModelType.Flux, BaseModelType.StableDiffusion35):
|
||||
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
@@ -551,10 +563,6 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
|
||||
if is_sd3_checkpoint(state_dict):
|
||||
return BaseModelType.StableDiffusion35
|
||||
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
@@ -760,18 +768,33 @@ class FolderProbeBase(ProbeBase):
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
|
||||
config_path = self.model_path / "unet" / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
# Handle pipelines with a transformer (i.e. SD3).
|
||||
config_path = self.model_path / "transformer" / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as file:
|
||||
transformer_conf = json.load(file)
|
||||
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
|
||||
return BaseModelType.StableDiffusion3
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
@@ -783,6 +806,23 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||
|
||||
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
|
||||
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
|
||||
submodels: Dict[SubModelType, SubmodelDefinition] = {}
|
||||
for key, value in config.items():
|
||||
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
|
||||
continue
|
||||
model_loader = str(value[1])
|
||||
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
|
||||
variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
|
||||
submodels[SubModelType(key)] = SubmodelDefinition(
|
||||
path_or_prefix=(self.model_path / key).resolve().as_posix(),
|
||||
model_type=model_type,
|
||||
variant=variant_func and variant_func((self.model_path / key).as_posix()),
|
||||
)
|
||||
|
||||
return submodels
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
|
||||
@@ -13,6 +13,9 @@ class StarterModelWithoutDependencies(BaseModel):
|
||||
type: ModelType
|
||||
format: Optional[ModelFormat] = None
|
||||
is_installed: bool = False
|
||||
# allows us to track what models a user has installed across name changes within starter models
|
||||
# if you update a starter model name, please add the old one to this list for that starter model
|
||||
previous_names: list[str] = []
|
||||
|
||||
|
||||
class StarterModel(StarterModelWithoutDependencies):
|
||||
@@ -137,6 +140,22 @@ flux_dev = StarterModel(
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-medium",
|
||||
description="Medium SD3.5 Model: ~15GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
sd35_large = StarterModel(
|
||||
name="SD3.5 Large",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3.5-large",
|
||||
description="Large SD3.5 Model: ~19G",
|
||||
type=ModelType.Main,
|
||||
dependencies=[],
|
||||
)
|
||||
cyberrealistic_sd1 = StarterModel(
|
||||
name="CyberRealistic v4.1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
@@ -243,42 +262,46 @@ easy_neg_sd1 = StarterModel(
|
||||
# endregion
|
||||
# region IP Adapter
|
||||
ip_adapter_sd1 = StarterModel(
|
||||
name="IP Adapter",
|
||||
name="Standard Reference (IP Adapter)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
|
||||
description="IP-Adapter for SD 1.5 models",
|
||||
description="References images with a more generalized/looser degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
previous_names=["IP Adapter"],
|
||||
)
|
||||
ip_adapter_plus_sd1 = StarterModel(
|
||||
name="IP Adapter Plus",
|
||||
name="Precise Reference (IP Adapter Plus)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
|
||||
description="Refined IP-Adapter for SD 1.5 models",
|
||||
description="References images with a higher degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
previous_names=["IP Adapter Plus"],
|
||||
)
|
||||
ip_adapter_plus_face_sd1 = StarterModel(
|
||||
name="IP Adapter Plus Face",
|
||||
name="Face Reference (IP Adapter Plus Face)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
|
||||
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
|
||||
description="References images with a higher degree of precision, adapted for faces",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
previous_names=["IP Adapter Plus Face"],
|
||||
)
|
||||
ip_adapter_sdxl = StarterModel(
|
||||
name="IP Adapter SDXL",
|
||||
name="Standard Reference (IP Adapter ViT-H)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
|
||||
description="IP-Adapter for SDXL models",
|
||||
description="References images with a higher degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||
previous_names=["IP Adapter SDXL"],
|
||||
)
|
||||
ip_adapter_flux = StarterModel(
|
||||
name="XLabs FLUX IP-Adapter",
|
||||
name="Standard Reference (XLabs FLUX IP-Adapter v2)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
|
||||
description="FLUX IP-Adapter",
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/resolve/main/ip_adapter.safetensors",
|
||||
description="References images with a more generalized/looser degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[clip_vit_l_image_encoder],
|
||||
)
|
||||
@@ -299,157 +322,162 @@ qr_code_cnet_sdxl = StarterModel(
|
||||
type=ModelType.ControlNet,
|
||||
)
|
||||
canny_sd1 = StarterModel(
|
||||
name="canny",
|
||||
name="Hard Edge Detection (canny)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_canny",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
||||
description="Uses detected edges in the image to control composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["canny"],
|
||||
)
|
||||
inpaint_cnet_sd1 = StarterModel(
|
||||
name="inpaint",
|
||||
name="Inpainting",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["inpaint"],
|
||||
)
|
||||
mlsd_sd1 = StarterModel(
|
||||
name="mlsd",
|
||||
name="Line Drawing (mlsd)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
description="Uses straight line detection for controlling the generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["mlsd"],
|
||||
)
|
||||
depth_sd1 = StarterModel(
|
||||
name="depth",
|
||||
name="Depth Map",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
||||
description="Uses depth information in the image to control the depth in the generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["depth"],
|
||||
)
|
||||
normal_bae_sd1 = StarterModel(
|
||||
name="normal_bae",
|
||||
name="Lighting Detection (Normals)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
description="Uses detected lighting information to guide the lighting of the composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["normal_bae"],
|
||||
)
|
||||
seg_sd1 = StarterModel(
|
||||
name="seg",
|
||||
name="Segmentation Map",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_seg",
|
||||
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
||||
description="Uses segmentation maps to guide the structure of the composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["seg"],
|
||||
)
|
||||
lineart_sd1 = StarterModel(
|
||||
name="lineart",
|
||||
name="Lineart",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_lineart",
|
||||
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
||||
description="Uses lineart detection to guide the lighting of the composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["lineart"],
|
||||
)
|
||||
lineart_anime_sd1 = StarterModel(
|
||||
name="lineart_anime",
|
||||
name="Lineart Anime",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
||||
description="Uses anime lineart detection to guide the lighting of the composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["lineart_anime"],
|
||||
)
|
||||
openpose_sd1 = StarterModel(
|
||||
name="openpose",
|
||||
name="Pose Detection (openpose)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_openpose",
|
||||
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
||||
description="Uses pose information to control the pose of human characters in the generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["openpose"],
|
||||
)
|
||||
scribble_sd1 = StarterModel(
|
||||
name="scribble",
|
||||
name="Contour Detection (scribble)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_scribble",
|
||||
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
||||
description="Uses edges, contours, or line art in the image to control composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["scribble"],
|
||||
)
|
||||
softedge_sd1 = StarterModel(
|
||||
name="softedge",
|
||||
name="Soft Edge Detection (softedge)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_softedge",
|
||||
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
||||
description="Uses a soft edge detection map to control composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["softedge"],
|
||||
)
|
||||
shuffle_sd1 = StarterModel(
|
||||
name="shuffle",
|
||||
name="Remix (shuffle)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["shuffle"],
|
||||
)
|
||||
tile_sd1 = StarterModel(
|
||||
name="tile",
|
||||
name="Tile",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
)
|
||||
ip2p_sd1 = StarterModel(
|
||||
name="ip2p",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
description="Uses image data to replicate exact colors/structure in the resulting generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["tile"],
|
||||
)
|
||||
canny_sdxl = StarterModel(
|
||||
name="canny-sdxl",
|
||||
name="Hard Edge Detection (canny)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-canny-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
description="Uses detected edges in the image to control composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["canny-sdxl"],
|
||||
)
|
||||
depth_sdxl = StarterModel(
|
||||
name="depth-sdxl",
|
||||
name="Depth Map",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlNet-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
description="Uses depth information in the image to control the depth in the generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["depth-sdxl"],
|
||||
)
|
||||
softedge_sdxl = StarterModel(
|
||||
name="softedge-dexined-sdxl",
|
||||
name="Soft Edge Detection (softedge)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
||||
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
type=ModelType.ControlNet,
|
||||
)
|
||||
depth_zoe_16_sdxl = StarterModel(
|
||||
name="depth-16bit-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
)
|
||||
depth_zoe_32_sdxl = StarterModel(
|
||||
name="depth-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
description="Uses a soft edge detection map to control composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["softedge-dexined-sdxl"],
|
||||
)
|
||||
openpose_sdxl = StarterModel(
|
||||
name="openpose-sdxl",
|
||||
name="Pose Detection (openpose)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-openpose-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
description="Uses pose information to control the pose of human characters in the generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["openpose-sdxl", "controlnet-openpose-sdxl"],
|
||||
)
|
||||
scribble_sdxl = StarterModel(
|
||||
name="scribble-sdxl",
|
||||
name="Contour Detection (scribble)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-scribble-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
description="Uses edges, contours, or line art in the image to control composition.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["scribble-sdxl", "controlnet-scribble-sdxl"],
|
||||
)
|
||||
tile_sdxl = StarterModel(
|
||||
name="tile-sdxl",
|
||||
name="Tile",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlNet-tile-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
description="Uses image data to replicate exact colors/structure in the resulting generation.",
|
||||
type=ModelType.ControlNet,
|
||||
previous_names=["tile-sdxl"],
|
||||
)
|
||||
union_cnet_sdxl = StarterModel(
|
||||
name="Multi-Guidance Detection (Union Pro)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="InvokeAI/Xinsir-SDXL_Controlnet_Union",
|
||||
description="A unified ControlNet for SDXL model that supports 10+ control types",
|
||||
type=ModelType.ControlNet,
|
||||
)
|
||||
union_cnet_flux = StarterModel(
|
||||
@@ -462,60 +490,52 @@ union_cnet_flux = StarterModel(
|
||||
# endregion
|
||||
# region T2I Adapter
|
||||
t2i_canny_sd1 = StarterModel(
|
||||
name="canny-sd15",
|
||||
name="Hard Edge Detection (canny)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_canny_sd15v2",
|
||||
description="T2I Adapter weights trained on sd-1.5 with canny conditioning.",
|
||||
description="Uses detected edges in the image to control composition",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["canny-sd15"],
|
||||
)
|
||||
t2i_sketch_sd1 = StarterModel(
|
||||
name="sketch-sd15",
|
||||
name="Sketch",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_sketch_sd15v2",
|
||||
description="T2I Adapter weights trained on sd-1.5 with sketch conditioning.",
|
||||
description="Uses a sketch to control composition",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["sketch-sd15"],
|
||||
)
|
||||
t2i_depth_sd1 = StarterModel(
|
||||
name="depth-sd15",
|
||||
name="Depth Map",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_depth_sd15v2",
|
||||
description="T2I Adapter weights trained on sd-1.5 with depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
)
|
||||
t2i_zoe_depth_sd1 = StarterModel(
|
||||
name="zoedepth-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_zoedepth_sd15v1",
|
||||
description="T2I Adapter weights trained on sd-1.5 with zoe depth conditioning.",
|
||||
description="Uses depth information in the image to control the depth in the generation.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["depth-sd15"],
|
||||
)
|
||||
t2i_canny_sdxl = StarterModel(
|
||||
name="canny-sdxl",
|
||||
name="Hard Edge Detection (canny)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with canny conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
)
|
||||
t2i_zoe_depth_sdxl = StarterModel(
|
||||
name="zoedepth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with zoe depth conditioning.",
|
||||
description="Uses detected edges in the image to control composition",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["canny-sdxl"],
|
||||
)
|
||||
t2i_lineart_sdxl = StarterModel(
|
||||
name="lineart-sdxl",
|
||||
name="Lineart",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with lineart conditioning.",
|
||||
description="Uses lineart detection to guide the lighting of the composition.",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["lineart-sdxl"],
|
||||
)
|
||||
t2i_sketch_sdxl = StarterModel(
|
||||
name="sketch-sdxl",
|
||||
name="Sketch",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with sketch conditioning.",
|
||||
description="Uses a sketch to control composition",
|
||||
type=ModelType.T2IAdapter,
|
||||
previous_names=["sketch-sdxl"],
|
||||
)
|
||||
# endregion
|
||||
# region SpandrelImageToImage
|
||||
@@ -565,6 +585,8 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
flux_dev_quantized,
|
||||
flux_schnell,
|
||||
flux_dev,
|
||||
sd35_medium,
|
||||
sd35_large,
|
||||
cyberrealistic_sd1,
|
||||
rev_animated_sd1,
|
||||
dreamshaper_8_sd1,
|
||||
@@ -600,22 +622,18 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
softedge_sd1,
|
||||
shuffle_sd1,
|
||||
tile_sd1,
|
||||
ip2p_sd1,
|
||||
canny_sdxl,
|
||||
depth_sdxl,
|
||||
softedge_sdxl,
|
||||
depth_zoe_16_sdxl,
|
||||
depth_zoe_32_sdxl,
|
||||
openpose_sdxl,
|
||||
scribble_sdxl,
|
||||
tile_sdxl,
|
||||
union_cnet_sdxl,
|
||||
union_cnet_flux,
|
||||
t2i_canny_sd1,
|
||||
t2i_sketch_sd1,
|
||||
t2i_depth_sd1,
|
||||
t2i_zoe_depth_sd1,
|
||||
t2i_canny_sdxl,
|
||||
t2i_zoe_depth_sdxl,
|
||||
t2i_lineart_sdxl,
|
||||
t2i_sketch_sdxl,
|
||||
realesrgan_x4,
|
||||
@@ -646,7 +664,6 @@ sd1_bundle: list[StarterModel] = [
|
||||
softedge_sd1,
|
||||
shuffle_sd1,
|
||||
tile_sd1,
|
||||
ip2p_sd1,
|
||||
swinir,
|
||||
]
|
||||
|
||||
@@ -657,8 +674,6 @@ sdxl_bundle: list[StarterModel] = [
|
||||
canny_sdxl,
|
||||
depth_sdxl,
|
||||
softedge_sdxl,
|
||||
depth_zoe_16_sdxl,
|
||||
depth_zoe_32_sdxl,
|
||||
openpose_sdxl,
|
||||
scribble_sdxl,
|
||||
tile_sdxl,
|
||||
|
||||
@@ -8,6 +8,7 @@ import safetensors
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_manager.config import ClipVariantType
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
|
||||
|
||||
@@ -43,7 +44,7 @@ def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
path_str = path.as_posix() if isinstance(path, Path) else path
|
||||
@@ -54,7 +55,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
|
||||
else:
|
||||
if scan:
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||
if str(path).endswith(".gguf"):
|
||||
# The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function
|
||||
@@ -165,3 +166,25 @@ def convert_bundle_to_flux_transformer_checkpoint(
|
||||
del transformer_state_dict[k]
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
|
||||
try:
|
||||
path = Path(location)
|
||||
config_path = path / "config.json"
|
||||
if not config_path.exists():
|
||||
config_path = path / "text_encoder" / "config.json"
|
||||
if not config_path.exists():
|
||||
return ClipVariantType.L
|
||||
with open(config_path) as file:
|
||||
clip_conf = json.load(file)
|
||||
hidden_size = clip_conf.get("hidden_size", -1)
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return ClipVariantType.L
|
||||
except Exception:
|
||||
return ClipVariantType.L
|
||||
|
||||
@@ -85,6 +85,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
result: set[Path] = set()
|
||||
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
|
||||
safetensors_detected = False
|
||||
for path in files:
|
||||
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
|
||||
if variant == ModelRepoVariant.ONNX:
|
||||
@@ -119,19 +120,27 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
|
||||
# variant and format and select the best one.
|
||||
|
||||
if safetensors_detected and path.suffix == ".bin":
|
||||
continue
|
||||
|
||||
parent = path.parent
|
||||
score = 0
|
||||
|
||||
if path.suffix == ".safetensors":
|
||||
safetensors_detected = True
|
||||
if parent in subfolder_weights:
|
||||
subfolder_weights[parent] = [sfc for sfc in subfolder_weights[parent] if sfc.path.suffix != ".bin"]
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
if candidate_variant_label == f".{variant}" or (
|
||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||
):
|
||||
if (
|
||||
variant is not ModelRepoVariant.Default
|
||||
and candidate_variant_label
|
||||
and candidate_variant_label.startswith(f".{variant.value}")
|
||||
) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
|
||||
score += 1
|
||||
|
||||
if parent not in subfolder_weights:
|
||||
@@ -146,7 +155,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# Check if at least one of the files has the explicit fp16 variant.
|
||||
at_least_one_fp16 = False
|
||||
for candidate in candidate_list:
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
|
||||
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
|
||||
at_least_one_fp16 = True
|
||||
break
|
||||
|
||||
@@ -162,7 +171,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
# candidate.
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
result.add(highest_score_candidate.path)
|
||||
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
|
||||
match = re.match(pattern, highest_score_candidate.path.as_posix())
|
||||
if match:
|
||||
for candidate in candidate_list:
|
||||
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
|
||||
match.group(2)
|
||||
):
|
||||
result.add(candidate.path)
|
||||
else:
|
||||
result.add(highest_score_candidate.path)
|
||||
|
||||
# If one of the architecture-related variants was specified and no files matched other than
|
||||
# config and text files then we return an empty list
|
||||
|
||||
0
invokeai/backend/sd3/extensions/__init__.py
Normal file
0
invokeai/backend/sd3/extensions/__init__.py
Normal file
58
invokeai/backend/sd3/extensions/inpaint_extension.py
Normal file
58
invokeai/backend/sd3/extensions/inpaint_extension.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
|
||||
|
||||
class InpaintExtension:
|
||||
"""A class for managing inpainting with SD3."""
|
||||
|
||||
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
|
||||
"""Initialize InpaintExtension.
|
||||
|
||||
Args:
|
||||
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0).
|
||||
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
|
||||
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
|
||||
inpainted region with the background.
|
||||
noise (torch.Tensor): The noise tensor used to noise the init_latents.
|
||||
"""
|
||||
assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4
|
||||
assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:]
|
||||
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
|
||||
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
|
||||
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
|
||||
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of
|
||||
# 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small)
|
||||
# number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff
|
||||
# threshold worked well.
|
||||
# We use a small epsilon to avoid any potential issues with floating point precision.
|
||||
eps = 1e-4
|
||||
mask_gradient_t_cutoff = 0.5
|
||||
if t_prev > mask_gradient_t_cutoff:
|
||||
# Early in the denoising process, use the inpaint mask as-is.
|
||||
return self._inpaint_mask
|
||||
else:
|
||||
# After the cut-off, promote all non-zero mask values to 1.0.
|
||||
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0)
|
||||
|
||||
return mask
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, t_prev: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
|
||||
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
|
||||
trajectory.
|
||||
|
||||
This function should be called after each denoising step.
|
||||
"""
|
||||
|
||||
mask = self._apply_mask_gradient_adjustment(t_prev)
|
||||
|
||||
# Noise the init latents for the current timestep.
|
||||
noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents
|
||||
|
||||
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
|
||||
@@ -1,891 +0,0 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/mmditx.py
|
||||
|
||||
|
||||
### This file contains impls for MM-DiT, the core model component of SD3
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from invokeai.backend.sd3.other_impls import Mlp, attention
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Optional[int] = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size, strict=False)])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = torch.nn.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor | None, scale: torch.Tensor) -> torch.Tensor:
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int,
|
||||
grid_size: int,
|
||||
cls_token: bool = False,
|
||||
extra_tokens: int = 0,
|
||||
scaling_factor: Optional[float] = None,
|
||||
offset: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
if scaling_factor is not None:
|
||||
grid = grid / scaling_factor
|
||||
if offset is not None:
|
||||
grid = grid - offset
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(torch.nn.Module):
|
||||
"""Embeds scalar timesteps into vector representations."""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
frequency_embedding_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
device=t.device
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class VectorEmbedder(torch.nn.Module):
|
||||
"""Embeds a flat vector of dimension input_dim"""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
|
||||
class SelfAttention(torch.nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
attn_mode: str = "xformers",
|
||||
pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
rmsnorm: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = torch.nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
self.pre_only = pre_only
|
||||
|
||||
if qk_norm == "rms":
|
||||
self.ln_q = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm == "ln":
|
||||
self.ln_q = torch.nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = torch.nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm is None:
|
||||
self.ln_q = torch.nn.Identity()
|
||||
self.ln_k = torch.nn.Identity()
|
||||
else:
|
||||
raise ValueError(qk_norm)
|
||||
|
||||
def pre_attention(self, x: torch.Tensor):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.head_dim)
|
||||
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||
return (q, k, v)
|
||||
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
(q, k, v) = self.pre_attention(x)
|
||||
x = attention(q, k, v, self.num_heads)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine: bool = False,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (torch.nn.Parameter): Learnable scaling parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.learnable_scale = elementwise_affine
|
||||
if self.learnable_scale:
|
||||
self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUFeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class DismantledBlock(torch.nn.Module):
|
||||
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
|
||||
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: str = "xformers",
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if not rmsnorm:
|
||||
self.norm1 = torch.nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=pre_only,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = torch.nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=torch.nn.GELU(approximate="tanh"),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert x is not None, "pre_attention called with None input"
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(
|
||||
6, dim=1
|
||||
)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(
|
||||
self,
|
||||
attn: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
gate_msa: torch.Tensor,
|
||||
shift_mlp: torch.Tensor,
|
||||
scale_mlp: torch.Tensor,
|
||||
gate_mlp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
def pre_attention_x(
|
||||
self, x: torch.Tensor, c: torch.Tensor
|
||||
) -> tuple[
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
]:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return (
|
||||
qkv,
|
||||
qkv2,
|
||||
(
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
),
|
||||
)
|
||||
|
||||
def post_attention_x(
|
||||
self,
|
||||
attn: torch.Tensor,
|
||||
attn2: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
gate_msa: torch.Tensor,
|
||||
shift_mlp: torch.Tensor,
|
||||
scale_mlp: torch.Tensor,
|
||||
gate_mlp: torch.Tensor,
|
||||
gate_msa2: torch.Tensor,
|
||||
attn1_dropout: float = 0.0,
|
||||
):
|
||||
assert not self.pre_only
|
||||
if attn1_dropout > 0.0:
|
||||
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
||||
attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
|
||||
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
||||
else:
|
||||
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + attn_
|
||||
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
||||
x = x + attn2_
|
||||
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
x = x + mlp_
|
||||
return x, (gate_msa, gate_msa2, gate_mlp, attn_, attn2_)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert not self.pre_only
|
||||
if self.x_block_self_attn:
|
||||
(q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
attn2 = attention(q2, k2, v2, self.attn2.num_heads)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
(q, k, v), intermediates = self.pre_attention(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(
|
||||
context: torch.Tensor, x: torch.Tensor, context_block: DismantledBlock, x_block: DismantledBlock, c: torch.Tensor
|
||||
):
|
||||
assert context is not None, "block_mixing called with None context"
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o: list[torch.Tensor] = []
|
||||
for t in range(3):
|
||||
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
|
||||
q, k, v = tuple(o)
|
||||
|
||||
attn = attention(q, k, v, x_block.attn.num_heads)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
attn[:, context_qkv[0].shape[1] :],
|
||||
)
|
||||
|
||||
if not context_block.pre_only:
|
||||
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||
else:
|
||||
context = None
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
x_q2, x_k2, x_v2 = x_qkv2
|
||||
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
|
||||
return context, x
|
||||
|
||||
|
||||
class JointBlock(torch.nn.Module):
|
||||
"""just a small wrapper to serve as a fsdp unit"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(
|
||||
*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)
|
||||
|
||||
|
||||
class FinalLayer(torch.nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
total_out_channels: Optional[int] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = torch.nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.linear = (
|
||||
torch.nn.Linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if (total_out_channels is None)
|
||||
else torch.nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MMDiTX(torch.nn.Module):
|
||||
"""Diffusion model with a Transformer backbone."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int | None = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
depth: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
learn_sigma: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
context_embedder_config: Optional[Dict] = None,
|
||||
register_length: int = 0,
|
||||
attn_mode: str = "torch",
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
pos_embed_scaling_factor: Optional[float] = None,
|
||||
pos_embed_offset: Optional[float] = None,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
num_patches: Optional[int] = None,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn_layers: Optional[List[int]] = None,
|
||||
qkv_bias: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if verbose:
|
||||
print(
|
||||
f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.x_block_self_attn_layers = x_block_self_attn_layers or []
|
||||
|
||||
# apply magic --> this defines a head_size of 64
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=self.pos_embed_max_size is None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
if adm_in_channels is not None:
|
||||
assert isinstance(adm_in_channels, int)
|
||||
self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.context_embedder = torch.nn.Identity()
|
||||
if context_embedder_config is not None:
|
||||
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||
self.context_embedder = torch.nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
|
||||
|
||||
self.register_length = register_length
|
||||
if self.register_length > 0:
|
||||
self.register = torch.nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))
|
||||
|
||||
# num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
# just use a buffer already
|
||||
if num_patches is not None:
|
||||
self.register_buffer(
|
||||
"pos_embed",
|
||||
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.joint_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=i == depth - 1,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
def cropped_pos_embed(self, hw: torch.Size) -> torch.Tensor:
|
||||
assert self.pos_embed_max_size is not None
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h, w = hw
|
||||
# patched size
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||
top = (self.pos_embed_max_size - h) // 2
|
||||
left = (self.pos_embed_max_size - w) // 2
|
||||
spatial_pos_embed: torch.Tensor = rearrange(
|
||||
self.pos_embed,
|
||||
"1 (h w) c -> 1 h w c",
|
||||
h=self.pos_embed_max_size,
|
||||
w=self.pos_embed_max_size,
|
||||
) # type: ignore Type checking does not correctly infer the type of the self.pos_embed buffer.
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, hw: Optional[torch.Size] = None) -> torch.Tensor:
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
if hw is None:
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
else:
|
||||
h, w = hw
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
||||
def forward_core_with_concat(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||
context if context is not None else torch.Tensor([]).type_as(x),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
for block in self.joint_blocks:
|
||||
context, x = block(context, x, c=c_mod)
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of class labels
|
||||
"""
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
c = c + y # (N, D)
|
||||
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x
|
||||
@@ -1,795 +0,0 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/other_impls.py
|
||||
|
||||
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
#################################################################################################
|
||||
### Core/Utility
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
|
||||
|
||||
class Mlp(torch.nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
||||
bias: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
if act_layer is None:
|
||||
act_layer = torch.nn.functional.gelu
|
||||
|
||||
self.fc1 = torch.nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.act = act_layer
|
||||
self.fc2 = torch.nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### CLIP
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
out = attention(q, k, v, self.heads, mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
ACTIVATIONS = {
|
||||
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
}
|
||||
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm1 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||
self.layer_norm2 = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
# self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.mlp = Mlp(
|
||||
embed_dim,
|
||||
intermediate_size,
|
||||
embed_dim,
|
||||
act_layer=ACTIVATIONS[intermediate_activation],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x += self.self_attn(self.layer_norm1(x), mask)
|
||||
x += self.mlp(self.layer_norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class CLIPEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
CLIPLayer(
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x, mask)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||
self.encoder = CLIPEncoder(
|
||||
num_layers,
|
||||
embed_dim,
|
||||
heads,
|
||||
intermediate_size,
|
||||
intermediate_activation,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
x = self.embeddings(input_tokens)
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
|
||||
x = self.final_layer_norm(x)
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
pooled_output = x[
|
||||
torch.arange(x.shape[0], device=x.device),
|
||||
input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
|
||||
]
|
||||
return x, i, pooled_output
|
||||
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
nesting_level = 0
|
||||
for char in string:
|
||||
if char == "(":
|
||||
if nesting_level == 0:
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
current_item = "("
|
||||
else:
|
||||
current_item = "("
|
||||
else:
|
||||
current_item += char
|
||||
nesting_level += 1
|
||||
elif char == ")":
|
||||
nesting_level -= 1
|
||||
if nesting_level == 0:
|
||||
result.append(current_item + ")")
|
||||
current_item = ""
|
||||
else:
|
||||
current_item += char
|
||||
else:
|
||||
current_item += char
|
||||
if current_item:
|
||||
result.append(current_item)
|
||||
return result
|
||||
|
||||
|
||||
def token_weights(string, current_weight):
|
||||
a = parse_parentheses(string)
|
||||
out = []
|
||||
for x in a:
|
||||
weight = current_weight
|
||||
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
|
||||
x = x[1:-1]
|
||||
xx = x.rfind(":")
|
||||
weight *= 1.1
|
||||
if xx > 0:
|
||||
try:
|
||||
weight = float(x[xx + 1 :])
|
||||
x = x[:xx]
|
||||
except:
|
||||
pass
|
||||
out += token_weights(x, weight)
|
||||
else:
|
||||
out += [(x, current_weight)]
|
||||
return out
|
||||
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
max_length=77,
|
||||
pad_with_end=True,
|
||||
tokenizer=None,
|
||||
has_start_token=True,
|
||||
pad_to_max_length=True,
|
||||
min_length=None,
|
||||
extra_padding_token=None,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
|
||||
empty = self.tokenizer("")["input_ids"]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
self.extra_padding_token = extra_padding_token
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.max_word_length = 8
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
"""
|
||||
Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.
|
||||
"""
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
|
||||
text = escape_important(text)
|
||||
parsed_weights = token_weights(text, 1.0)
|
||||
|
||||
# tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(" ")
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
# parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
|
||||
|
||||
# reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
# determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
# break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
# add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
# start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group])
|
||||
t_group = []
|
||||
|
||||
# pad extra padding token first befor getting to the end token
|
||||
if self.extra_padding_token is not None:
|
||||
batch.extend([(self.extra_padding_token, 1.0, 0)] * (self.min_length - len(batch) - 1))
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
|
||||
|
||||
return batched_tokens
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
|
||||
class SDXLClipGTokenizer(SDTokenizer):
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(pad_with_end=False, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self):
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||
self.t5xxl = T5XXLTokenizer()
|
||||
|
||||
def tokenize_with_weights(self, text: str):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
|
||||
return out
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
||||
out, pooled = self([tokens])
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled
|
||||
output = [out[0:1]]
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device="cpu",
|
||||
max_length=77,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
textmodel_json_config=None,
|
||||
dtype=None,
|
||||
model_class=CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},
|
||||
layer_norm_hidden_state=True,
|
||||
return_projected_pooled=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
self.max_length = max_length
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (
|
||||
self.layer,
|
||||
self.layer_idx,
|
||||
self.return_projected_pooled,
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
outputs = self.transformer(
|
||||
tokens,
|
||||
intermediate_output=self.layer_idx,
|
||||
final_layer_norm_intermediate=self.layer_norm_hidden_state,
|
||||
)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
else:
|
||||
z = outputs[1]
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
|
||||
class SDXLClipG(SDClipModel):
|
||||
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
|
||||
|
||||
def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
)
|
||||
|
||||
|
||||
class T5XXLModel(SDClipModel):
|
||||
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
|
||||
|
||||
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
super().__init__(
|
||||
device=device,
|
||||
layer=layer,
|
||||
layer_idx=layer_idx,
|
||||
textmodel_json_config=config,
|
||||
dtype=dtype,
|
||||
special_tokens={"end": 1, "pad": 0},
|
||||
model_class=T5,
|
||||
)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
pad_with_end=False,
|
||||
tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
|
||||
has_start_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=77,
|
||||
)
|
||||
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
forwarded_states = self.layer_norm(x)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
x += forwarded_states
|
||||
return x
|
||||
|
||||
|
||||
class T5Attention(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
|
||||
super().__init__()
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = torch.nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = torch.nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(
|
||||
self.relative_attention_num_buckets, self.num_heads, device=device
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||
|
||||
Args:
|
||||
relative_position: an int32 Tensor
|
||||
bidirectional: a boolean - whether the attention is bidirectional
|
||||
num_buckets: an integer
|
||||
max_distance: an integer
|
||||
|
||||
Returns:
|
||||
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||
"""
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
||||
# now relative_position is in the range [0, inf)
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_position_if_large = torch.min(
|
||||
relative_position_if_large,
|
||||
torch.full_like(relative_position_if_large, num_buckets - 1),
|
||||
)
|
||||
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=True,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
if past_bias is not None:
|
||||
mask = past_bias
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
|
||||
return self.o(out), past_bias
|
||||
|
||||
|
||||
class T5LayerSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
|
||||
x += output
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(
|
||||
T5LayerSelfAttention(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
)
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
x, past_bias = self.layer[0](x, past_bias)
|
||||
x = self.layer[-1](x)
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers,
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
vocab_size,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||
self.block = torch.nn.ModuleList(
|
||||
[
|
||||
T5Block(
|
||||
model_dim,
|
||||
inner_dim,
|
||||
ff_dim,
|
||||
num_heads,
|
||||
relative_attention_bias=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids)
|
||||
past_bias = None
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class T5(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
self.encoder = T5Stack(
|
||||
self.num_layers,
|
||||
config_dict["d_model"],
|
||||
config_dict["d_model"],
|
||||
config_dict["d_ff"],
|
||||
config_dict["num_heads"],
|
||||
config_dict["vocab_size"],
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.encoder.embed_tokens = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.encoder(*args, **kwargs)
|
||||
@@ -1,609 +0,0 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py
|
||||
|
||||
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.sd3.mmditx import MMDiTX
|
||||
|
||||
#################################################################################################
|
||||
### MMDiT Model Wrapping
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||
|
||||
def __init__(self, shift: float = 1.0):
|
||||
super().__init__()
|
||||
self.shift = shift
|
||||
timesteps = 1000
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def calculate_denoised(
|
||||
self, sigma: torch.Tensor, model_output: torch.Tensor, model_input: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shift=1.0,
|
||||
device=None,
|
||||
dtype=torch.float32,
|
||||
file=None,
|
||||
prefix="",
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
|
||||
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
|
||||
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
|
||||
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
|
||||
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in file.keys() else None
|
||||
x_block_self_attn_layers = sorted(
|
||||
[
|
||||
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||
for key in list(filter(re.compile(".*.x_block.attn2.ln_k.weight").match, file.keys()))
|
||||
]
|
||||
)
|
||||
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0],
|
||||
},
|
||||
}
|
||||
self.diffusion_model = MMDiTX(
|
||||
input_size=None,
|
||||
pos_embed_scaling_factor=None,
|
||||
pos_embed_offset=None,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=16,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
adm_in_channels=adm_in_channels,
|
||||
context_embedder_config=context_embedder_config,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
|
||||
def apply_model(
|
||||
self, x: torch.Tensor, sigma: float, c_crossattn: torch.Tensor | None = None, y: torch.Tensor | None = None
|
||||
):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""Helper for applying CFG Scaling to diffusion outputs"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, x, timestep, cond, uncond, cond_scale):
|
||||
# Run cond and uncond in a batch together
|
||||
batched = self.model.apply_model(
|
||||
torch.cat([x, x]),
|
||||
torch.cat([timestep, timestep]),
|
||||
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
|
||||
y=torch.cat([cond["y"], uncond["y"]]),
|
||||
)
|
||||
# Then split and apply CFG Scaling
|
||||
pos_out, neg_out = batched.chunk(2)
|
||||
scaled = neg_out + (pos_out - neg_out) * cond_scale
|
||||
return scaled
|
||||
|
||||
|
||||
class SD3LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
"""Quick RGB approximate preview of sd3 latents"""
|
||||
factors = torch.tensor(
|
||||
[
|
||||
[-0.0645, 0.0177, 0.1052],
|
||||
[0.0028, 0.0312, 0.0650],
|
||||
[0.1848, 0.0762, 0.0360],
|
||||
[0.0944, 0.0360, 0.0889],
|
||||
[0.0897, 0.0506, -0.0364],
|
||||
[-0.0020, 0.1203, 0.0284],
|
||||
[0.0855, 0.0118, 0.0283],
|
||||
[-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700],
|
||||
[-0.0412, 0.0281, -0.0039],
|
||||
[0.1106, 0.1171, 0.1220],
|
||||
[-0.0248, 0.0682, -0.0481],
|
||||
[0.0815, 0.0846, 0.1207],
|
||||
[-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456],
|
||||
[-0.1418, -0.1457, -0.1259],
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||
|
||||
latents_ubyte = (
|
||||
((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()
|
||||
).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Samplers
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_euler(model, x, sigmas, extra_args=None):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_dpmpp_2m(model, x, sigmas, extra_args=None):
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.log().neg()
|
||||
old_denoised = None
|
||||
for i in tqdm(range(len(sigmas) - 1)):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
|
||||
h = t_next - t
|
||||
if old_denoised is None or sigmas[i + 1] == 0:
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
|
||||
else:
|
||||
h_last = t - t_fn(sigmas[i - 1])
|
||||
r = h_last / h
|
||||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
|
||||
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### VAE
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = None
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = x
|
||||
hidden = self.norm1(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv1(hidden)
|
||||
hidden = self.norm2(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv2(hidden)
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class AttnBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.norm(x)
|
||||
q = self.q(hidden)
|
||||
k = self.k(hidden)
|
||||
v = self.v(hidden)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(
|
||||
lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
||||
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
hidden = self.proj_out(hidden)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class VAEEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
in_channels=3,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = torch.nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = torch.nn.ModuleList()
|
||||
attn = torch.nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = torch.nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
||||
self.down.append(down)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class VAEDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
resolution=256,
|
||||
z_channels=16,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
# upsampling
|
||||
self.up = torch.nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = torch.nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = torch.nn.Module()
|
||||
up.block = block
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
hidden = self.conv_in(z)
|
||||
# middle
|
||||
hidden = self.mid.block_1(hidden)
|
||||
hidden = self.mid.attn_1(hidden)
|
||||
hidden = self.mid.block_2(hidden)
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
hidden = self.up[i_level].block[i_block](hidden)
|
||||
if i_level != 0:
|
||||
hidden = self.up[i_level].upsample(hidden)
|
||||
# end
|
||||
hidden = self.norm_out(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv_out(hidden)
|
||||
return hidden
|
||||
|
||||
|
||||
class SDVAE(torch.nn.Module):
|
||||
def __init__(self, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
||||
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def decode(self, latent):
|
||||
return self.decoder(latent)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def encode(self, image):
|
||||
hidden = self.encoder(image)
|
||||
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, -30.0, 20.0)
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
@@ -1,426 +0,0 @@
|
||||
# This file was originally copied from:
|
||||
# https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_infer.py
|
||||
|
||||
# NOTE: Must have folder `models` with the following files:
|
||||
# - `clip_g.safetensors` (openclip bigG, same as SDXL)
|
||||
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
|
||||
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
|
||||
# - `sd3_medium.safetensors` (or whichever main MMDiT model file)
|
||||
# Also can have
|
||||
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
|
||||
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import sd3_impls
|
||||
import torch
|
||||
from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
|
||||
from PIL import Image
|
||||
from safetensors import safe_open
|
||||
from sd3_impls import SDVAE, BaseModel, CFGDenoiser, SD3LatentFormat
|
||||
from tqdm import tqdm
|
||||
|
||||
#################################################################################################
|
||||
### Wrappers for model parts
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def load_into(f, model, prefix, device, dtype=None):
|
||||
"""Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
|
||||
for key in f.keys():
|
||||
if key.startswith(prefix) and not key.startswith("loss."):
|
||||
path = key[len(prefix) :].split(".")
|
||||
obj = model
|
||||
for p in path:
|
||||
if obj is list:
|
||||
obj = obj[int(p)]
|
||||
else:
|
||||
obj = getattr(obj, p, None)
|
||||
if obj is None:
|
||||
print(f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model")
|
||||
break
|
||||
if obj is None:
|
||||
continue
|
||||
try:
|
||||
tensor = f.get_tensor(key).to(device=device)
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
obj.requires_grad_(False)
|
||||
obj.set_(tensor)
|
||||
except Exception as e:
|
||||
print(f"Failed to load key '{key}' in safetensors file: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
}
|
||||
|
||||
|
||||
class ClipG:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_g.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
|
||||
class ClipL:
|
||||
def __init__(self):
|
||||
with safe_open("models/clip_l.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = SDClipModel(
|
||||
layer="hidden",
|
||||
layer_idx=-2,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
layer_norm_hidden_state=False,
|
||||
return_projected_pooled=False,
|
||||
textmodel_json_config=CLIPL_CONFIG,
|
||||
)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class T5XXL:
|
||||
def __init__(self):
|
||||
with safe_open("models/t5xxl.safetensors", framework="pt", device="cpu") as f:
|
||||
self.model = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
|
||||
load_into(f, self.model.transformer, "", "cpu", torch.float32)
|
||||
|
||||
|
||||
class SD3:
|
||||
def __init__(self, model, shift, verbose=False):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = BaseModel(
|
||||
shift=shift,
|
||||
file=f,
|
||||
prefix="model.diffusion_model.",
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
verbose=verbose,
|
||||
).eval()
|
||||
load_into(f, self.model, "model.", "cpu", torch.float16)
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, model):
|
||||
with safe_open(model, framework="pt", device="cpu") as f:
|
||||
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
|
||||
prefix = ""
|
||||
if any(k.startswith("first_stage_model.") for k in f.keys()):
|
||||
prefix = "first_stage_model."
|
||||
load_into(f, self.model, prefix, "cpu", torch.float16)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Main inference logic
|
||||
#################################################################################################
|
||||
|
||||
|
||||
# Note: Sigma shift value, publicly released models use 3.0
|
||||
SHIFT = 3.0
|
||||
# Naturally, adjust to the width/height of the model you have
|
||||
WIDTH = 1024
|
||||
HEIGHT = 1024
|
||||
# Pick your prompt
|
||||
PROMPT = "a photo of a cat"
|
||||
# Most models prefer the range of 4-5, but still work well around 7
|
||||
CFG_SCALE = 4.5
|
||||
# Different models want different step counts but most will be good at 50, albeit that's slow to run
|
||||
# sd3_medium is quite decent at 28 steps
|
||||
STEPS = 40
|
||||
# Seed
|
||||
SEED = 23
|
||||
# SEEDTYPE = "fixed"
|
||||
SEEDTYPE = "rand"
|
||||
# SEEDTYPE = "roll"
|
||||
# Actual model file path
|
||||
# MODEL = "models/sd3_medium.safetensors"
|
||||
# MODEL = "models/sd3.5_large_turbo.safetensors"
|
||||
MODEL = "models/sd3.5_large.safetensors"
|
||||
# VAE model file path, or set None to use the same model file
|
||||
VAEFile = None # "models/sd3_vae.safetensors"
|
||||
# Optional init image file path
|
||||
INIT_IMAGE = None
|
||||
# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
|
||||
DENOISE = 0.6
|
||||
# Output file path
|
||||
OUTDIR = "outputs"
|
||||
# SAMPLER
|
||||
# SAMPLER = "euler"
|
||||
SAMPLER = "dpmpp_2m"
|
||||
|
||||
|
||||
class SD3Inferencer:
|
||||
def print(self, txt):
|
||||
if self.verbose:
|
||||
print(txt)
|
||||
|
||||
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT, verbose=False):
|
||||
self.verbose = verbose
|
||||
print("Loading tokenizers...")
|
||||
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
|
||||
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
|
||||
# (T5 tokenizer is different though)
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
print("Loading OpenAI CLIP L...")
|
||||
self.clip_l = ClipL()
|
||||
print("Loading OpenCLIP bigG...")
|
||||
self.clip_g = ClipG()
|
||||
print("Loading Google T5-v1-XXL...")
|
||||
self.t5xxl = T5XXL()
|
||||
print(f"Loading SD3 model {os.path.basename(model)}...")
|
||||
self.sd3 = SD3(model, shift, verbose)
|
||||
print("Loading VAE model...")
|
||||
self.vae = VAE(vae or model)
|
||||
print("Models loaded.")
|
||||
|
||||
def get_empty_latent(self, width, height):
|
||||
self.print("Prep an empty latent...")
|
||||
return torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
|
||||
|
||||
def get_sigmas(self, sampling, steps):
|
||||
start = sampling.timestep(sampling.sigma_max)
|
||||
end = sampling.timestep(sampling.sigma_min)
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
sigs = []
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
sigs.append(sampling.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def get_noise(self, seed, latent):
|
||||
generator = torch.manual_seed(seed)
|
||||
self.print(f"dtype = {latent.dtype}, layout = {latent.layout}, device = {latent.device}")
|
||||
return torch.randn(
|
||||
latent.size(),
|
||||
dtype=torch.float32,
|
||||
layout=latent.layout,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
).to(latent.dtype)
|
||||
|
||||
def get_cond(self, prompt):
|
||||
self.print("Encode prompt...")
|
||||
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
||||
l_out, l_pooled = self.clip_l.model.encode_token_weights(tokens["l"])
|
||||
g_out, g_pooled = self.clip_g.model.encode_token_weights(tokens["g"])
|
||||
t5_out, t5_pooled = self.t5xxl.model.encode_token_weights(tokens["t5xxl"])
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
def max_denoise(self, sigmas):
|
||||
max_sigma = float(self.sd3.model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
def fix_cond(self, cond):
|
||||
cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
|
||||
return {"c_crossattn": cond, "y": pooled}
|
||||
|
||||
def do_sampling(
|
||||
self,
|
||||
latent,
|
||||
seed,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler="dpmpp_2m",
|
||||
denoise=1.0,
|
||||
) -> torch.Tensor:
|
||||
self.print("Sampling...")
|
||||
latent = latent.half().cuda()
|
||||
self.sd3.model = self.sd3.model.cuda()
|
||||
noise = self.get_noise(seed, latent).cuda()
|
||||
sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
|
||||
sigmas = sigmas[int(steps * (1 - denoise)) :]
|
||||
conditioning = self.fix_cond(conditioning)
|
||||
neg_cond = self.fix_cond(neg_cond)
|
||||
extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
|
||||
noise_scaled = self.sd3.model.model_sampling.noise_scaling(sigmas[0], noise, latent, self.max_denoise(sigmas))
|
||||
sample_fn = getattr(sd3_impls, f"sample_{sampler}")
|
||||
latent = sample_fn(CFGDenoiser(self.sd3.model), noise_scaled, sigmas, extra_args=extra_args)
|
||||
latent = SD3LatentFormat().process_out(latent)
|
||||
self.sd3.model = self.sd3.model.cpu()
|
||||
self.print("Sampling done")
|
||||
return latent
|
||||
|
||||
def vae_encode(self, image) -> torch.Tensor:
|
||||
self.print("Encoding image to latent...")
|
||||
image = image.convert("RGB")
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
image_np = np.moveaxis(image_np, 2, 0)
|
||||
batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
|
||||
image_torch = torch.from_numpy(batch_images)
|
||||
image_torch = 2.0 * image_torch - 1.0
|
||||
image_torch = image_torch.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
latent = self.vae.model.encode(image_torch).cpu()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
self.print("Encoded")
|
||||
return latent
|
||||
|
||||
def vae_decode(self, latent) -> Image.Image:
|
||||
self.print("Decoding latent to image...")
|
||||
latent = latent.cuda()
|
||||
self.vae.model = self.vae.model.cuda()
|
||||
image = self.vae.model.decode(latent)
|
||||
image = image.float()
|
||||
self.vae.model = self.vae.model.cpu()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
decoded_np = decoded_np.astype(np.uint8)
|
||||
out_image = Image.fromarray(decoded_np)
|
||||
self.print("Decoded")
|
||||
return out_image
|
||||
|
||||
def gen_image(
|
||||
self,
|
||||
prompts=[PROMPT],
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
steps=STEPS,
|
||||
cfg_scale=CFG_SCALE,
|
||||
sampler=SAMPLER,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
out_dir=OUTDIR,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
):
|
||||
latent = self.get_empty_latent(width, height)
|
||||
if init_image:
|
||||
image_data = Image.open(init_image)
|
||||
image_data = image_data.resize((width, height), Image.LANCZOS)
|
||||
latent = self.vae_encode(image_data)
|
||||
latent = SD3LatentFormat().process_in(latent)
|
||||
neg_cond = self.get_cond("")
|
||||
seed_num = None
|
||||
pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
|
||||
for i, prompt in pbar:
|
||||
if seed_type == "roll":
|
||||
seed_num = seed if seed_num is None else seed_num + 1
|
||||
elif seed_type == "rand":
|
||||
seed_num = torch.randint(0, 100000, (1,)).item()
|
||||
else: # fixed
|
||||
seed_num = seed
|
||||
conditioning = self.get_cond(prompt)
|
||||
sampled_latent = self.do_sampling(
|
||||
latent,
|
||||
seed_num,
|
||||
conditioning,
|
||||
neg_cond,
|
||||
steps,
|
||||
cfg_scale,
|
||||
sampler,
|
||||
denoise if init_image else 1.0,
|
||||
)
|
||||
image = self.vae_decode(sampled_latent)
|
||||
save_path = os.path.join(out_dir, f"{i:06d}.png")
|
||||
self.print(f"Will save to {save_path}")
|
||||
image.save(save_path)
|
||||
self.print("Done")
|
||||
|
||||
|
||||
CONFIGS = {
|
||||
"sd3_medium": {
|
||||
"shift": 1.0,
|
||||
"cfg": 5.0,
|
||||
"steps": 50,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large": {
|
||||
"shift": 3.0,
|
||||
"cfg": 4.5,
|
||||
"steps": 40,
|
||||
"sampler": "dpmpp_2m",
|
||||
},
|
||||
"sd3.5_large_turbo": {"shift": 3.0, "cfg": 1.0, "steps": 4, "sampler": "euler"},
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main(
|
||||
prompt=PROMPT,
|
||||
model=MODEL,
|
||||
out_dir=OUTDIR,
|
||||
postfix=None,
|
||||
seed=SEED,
|
||||
seed_type=SEEDTYPE,
|
||||
sampler=None,
|
||||
steps=None,
|
||||
cfg=None,
|
||||
shift=None,
|
||||
width=WIDTH,
|
||||
height=HEIGHT,
|
||||
vae=VAEFile,
|
||||
init_image=INIT_IMAGE,
|
||||
denoise=DENOISE,
|
||||
verbose=False,
|
||||
):
|
||||
steps = steps or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["steps"]
|
||||
cfg = cfg or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["cfg"]
|
||||
shift = shift or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["shift"]
|
||||
sampler = sampler or CONFIGS[os.path.splitext(os.path.basename(model))[0]]["sampler"]
|
||||
|
||||
inferencer = SD3Inferencer()
|
||||
inferencer.load(model, vae, shift, verbose)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
if os.path.splitext(prompt)[-1] == ".txt":
|
||||
with open(prompt, "r") as f:
|
||||
prompts = [l.strip() for l in f.readlines()]
|
||||
else:
|
||||
prompts = [prompt]
|
||||
|
||||
out_dir = os.path.join(
|
||||
out_dir,
|
||||
os.path.splitext(os.path.basename(model))[0],
|
||||
os.path.splitext(os.path.basename(prompt))[0][:50]
|
||||
+ (postfix or datetime.datetime.now().strftime("_%Y-%m-%dT%H-%M-%S")),
|
||||
)
|
||||
print(f"Saving images to {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=False)
|
||||
|
||||
inferencer.gen_image(
|
||||
prompts,
|
||||
width,
|
||||
height,
|
||||
steps,
|
||||
cfg,
|
||||
sampler,
|
||||
seed,
|
||||
seed_type,
|
||||
out_dir,
|
||||
init_image,
|
||||
denoise,
|
||||
)
|
||||
|
||||
|
||||
fire.Fire(main)
|
||||
@@ -1,72 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.sd3.mmditx import MMDiTX
|
||||
from invokeai.backend.sd3.sd3_impls import ModelSamplingDiscreteFlow
|
||||
|
||||
|
||||
class ContextEmbedderConfig(TypedDict):
|
||||
target: Literal["torch.nn.Linear"]
|
||||
params: dict[str, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sd3MMDiTXParams:
|
||||
patch_size: int
|
||||
depth: int
|
||||
num_patches: int
|
||||
pos_embed_max_size: int
|
||||
adm_in_channels: int
|
||||
context_shape: tuple[int, int]
|
||||
qk_norm: Literal["rms", None]
|
||||
x_block_self_attn_layers: list[int]
|
||||
context_embedder_config: ContextEmbedderConfig
|
||||
|
||||
|
||||
class Sd3MMDiTX(torch.nn.Module):
|
||||
"""This class is based closely on
|
||||
https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py#L53
|
||||
but has more standard model loading semantics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Sd3MMDiTXParams,
|
||||
shift: float = 1.0,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.diffusion_model = MMDiTX(
|
||||
input_size=None,
|
||||
pos_embed_scaling_factor=None,
|
||||
pos_embed_offset=None,
|
||||
pos_embed_max_size=params.pos_embed_max_size,
|
||||
patch_size=params.patch_size,
|
||||
in_channels=16,
|
||||
depth=params.depth,
|
||||
num_patches=params.num_patches,
|
||||
adm_in_channels=params.adm_in_channels,
|
||||
context_embedder_config=params.context_embedder_config,
|
||||
qk_norm=params.qk_norm,
|
||||
x_block_self_attn_layers=params.x_block_self_attn_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
|
||||
def apply_model(self, x: torch.Tensor, sigma: torch.Tensor, c_crossattn: torch.Tensor, y: torch.Tensor):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, x: torch.Tensor, sigma: float, c_crossattn: torch.Tensor, y: torch.Tensor):
|
||||
return self.apply_model(x=x, sigma=sigma, c_crossattn=c_crossattn, y=y)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
@@ -1,70 +0,0 @@
|
||||
import math
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
from invokeai.backend.sd3.sd3_mmditx import ContextEmbedderConfig, Sd3MMDiTXParams
|
||||
|
||||
|
||||
def is_sd3_checkpoint(sd: Dict[str, Any]) -> bool:
|
||||
"""Is the state dict for an SD3 checkpoint like this one?:
|
||||
https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors
|
||||
|
||||
Note that this checkpoint format contains both the VAE and the MMDiTX model.
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
||||
"""
|
||||
# If all of the expected keys are present, then this is very likely a SD3 checkpoint.
|
||||
expected_keys = {
|
||||
# VAE decoder and encoder keys.
|
||||
"first_stage_model.decoder.conv_in.bias",
|
||||
"first_stage_model.decoder.conv_in.weight",
|
||||
"first_stage_model.encoder.conv_in.bias",
|
||||
"first_stage_model.encoder.conv_in.weight",
|
||||
# MMDiTX keys.
|
||||
"model.diffusion_model.final_layer.linear.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.ln_k.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.ln_q.weight",
|
||||
}
|
||||
|
||||
return expected_keys.issubset(sd.keys())
|
||||
|
||||
|
||||
def infer_sd3_mmditx_params(sd: Dict[str, Any], prefix: str = "model.diffusion_model.") -> Sd3MMDiTXParams:
|
||||
"""Infer the MMDiTX model parameters from the state dict.
|
||||
|
||||
This logic is based on:
|
||||
https://github.com/Stability-AI/sd3.5/blob/19bf11c4e1e37324c5aa5a61f010d4127848a09c/sd3_impls.py#L68-L88
|
||||
"""
|
||||
patch_size = sd[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = sd[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = sd[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = sd[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = sd[f"{prefix}context_embedder.weight"].shape
|
||||
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in sd else None
|
||||
x_block_self_attn_layers = sorted(
|
||||
[
|
||||
int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])
|
||||
for key in list(filter(re.compile(".*.x_block.attn2.ln_k.weight").match, sd.keys()))
|
||||
]
|
||||
)
|
||||
|
||||
context_embedder_config: ContextEmbedderConfig = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0],
|
||||
},
|
||||
}
|
||||
return Sd3MMDiTXParams(
|
||||
patch_size=patch_size,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
adm_in_channels=adm_in_channels,
|
||||
context_shape=context_shape,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
context_embedder_config=context_embedder_config,
|
||||
)
|
||||
@@ -499,6 +499,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
|
||||
accum_adapter_state[idx] += value * t2i_adapter_weight
|
||||
|
||||
# Hack: force compatibility with irregular resolutions by padding the feature map with zeros
|
||||
for idx, tensor in enumerate(accum_adapter_state):
|
||||
# The tensor size is supposed to be some integer downscale factor of the latents size.
|
||||
# Internally, the unet will pad the latents before downscaling between levels when it is no longer divisible by its downscale factor.
|
||||
# If the latent size does not scale down evenly, we need to pad the tensor so that it matches the the downscaled padded latents later on.
|
||||
scale_factor = latents.size()[-1] // tensor.size()[-1]
|
||||
required_padding_width = math.ceil(latents.size()[-1] / scale_factor) - tensor.size()[-1]
|
||||
required_padding_height = math.ceil(latents.size()[-2] / scale_factor) - tensor.size()[-2]
|
||||
tensor = torch.nn.functional.pad(
|
||||
tensor,
|
||||
(0, required_padding_width, 0, required_padding_height, 0, 0, 0, 0),
|
||||
mode="constant",
|
||||
value=0,
|
||||
)
|
||||
accum_adapter_state[idx] = tensor
|
||||
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
# Handle inpainting models.
|
||||
|
||||
@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SD3ConditioningInfo:
|
||||
clip_l_pooled_embeds: torch.Tensor
|
||||
clip_l_embeds: torch.Tensor
|
||||
clip_g_pooled_embeds: torch.Tensor
|
||||
clip_g_embeds: torch.Tensor
|
||||
t5_embeds: torch.Tensor | None
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
|
||||
if self.t5_embeds is not None:
|
||||
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
|
||||
conditionings: (
|
||||
List[BasicConditioningInfo]
|
||||
| List[SDXLConditioningInfo]
|
||||
| List[FLUXConditioningInfo]
|
||||
| List[SD3ConditioningInfo]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -33,7 +33,7 @@ class PreviewExt(ExtensionBase):
|
||||
def initial_preview(self, ctx: DenoiseContext):
|
||||
self.callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
step=0,
|
||||
order=ctx.scheduler.order,
|
||||
total_steps=len(ctx.inputs.timesteps),
|
||||
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlNetMixin
|
||||
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
from diffusers.models.embeddings import (
|
||||
@@ -32,7 +32,9 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
# NOTE(ryand): I'm not the origina author of this code, but for future reference, it appears that this class was copied
|
||||
# from diffusers in order to add support for the encoder_attention_mask argument.
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
12
invokeai/backend/util/prefix_logger_adapter.py
Normal file
12
invokeai/backend/util/prefix_logger_adapter.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import logging
|
||||
from typing import Any, MutableMapping
|
||||
|
||||
|
||||
# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855
|
||||
class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore
|
||||
def __init__(self, logger: logging.Logger, prefix: str):
|
||||
super().__init__(logger, {})
|
||||
self.prefix = prefix
|
||||
|
||||
def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]:
|
||||
return f"[{self.prefix}] {msg}", kwargs
|
||||
@@ -9,6 +9,7 @@ const config: KnipConfig = {
|
||||
'src/services/api/schema.ts',
|
||||
'src/features/nodes/types/v1/**',
|
||||
'src/features/nodes/types/v2/**',
|
||||
'src/features/parameters/types/parameterSchemas.ts',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// TODO(psyche): restore HRF functionality?
|
||||
|
||||
@@ -52,13 +52,13 @@
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"@atlaskit/pragmatic-drag-and-drop": "^1.4.0",
|
||||
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^1.4.0",
|
||||
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
|
||||
"@dagrejs/dagre": "^1.1.4",
|
||||
"@dagrejs/graphlib": "^2.2.4",
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
"@dnd-kit/sortable": "^8.0.0",
|
||||
"@dnd-kit/utilities": "^3.2.2",
|
||||
"@fontsource-variable/inter": "^5.1.0",
|
||||
"@invoke-ai/ui-library": "^0.0.42",
|
||||
"@invoke-ai/ui-library": "^0.0.44",
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@reduxjs/toolkit": "2.2.3",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user