Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-15 07:28:06 -05:00.

Compare commits: v5.3.0...ryan/lora- (493 commits)
.github/pull_request_template.md (vendored, 1 change)
@@ -19,3 +19,4 @@
 - [ ] _The PR has a short but descriptive title, suitable for a changelog_
 - [ ] _Tests added / updated (if applicable)_
 - [ ] _Documentation added / updated (if applicable)_
+- [ ] _Updated `What's New` copy (if doing a release after this PR)_
SECURITY.md (new file, 14 lines)
@@ -0,0 +1,14 @@
+# Security Policy
+
+## Supported Versions
+
+Only the latest version of Invoke will receive security updates.
+We do not currently maintain multiple versions of the application with updates.
+
+## Reporting a Vulnerability
+
+To report a vulnerability, contact the Invoke team directly at security@invoke.ai
+
+At this time, we do not maintain a formal bug bounty program.
+
+You can also share identified security issues with our team on huntr.com
@@ -2,29 +2,42 @@

 ## Builder stage

-FROM library/ubuntu:23.04 AS builder
+FROM library/ubuntu:24.04 AS builder

 ARG DEBIAN_FRONTEND=noninteractive
 RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
 RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     --mount=type=cache,target=/var/lib/apt,sharing=locked \
     apt update && apt-get install -y \
-    git \
-    python3-venv \
-    python3-pip \
-    build-essential
+    build-essential \
+    git

-ENV INVOKEAI_SRC=/opt/invokeai
-ENV VIRTUAL_ENV=/opt/venv/invokeai
+# Install `uv` for package management
+COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
+
+ENV VIRTUAL_ENV=/opt/venv
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+ENV INVOKEAI_SRC=/opt/invokeai
+ENV PYTHON_VERSION=3.11
+ENV UV_COMPILE_BYTECODE=1
+ENV UV_LINK_MODE=copy

 ARG GPU_DRIVER=cuda
 ARG TARGETPLATFORM="linux/amd64"
 # unused but available
 ARG BUILDPLATFORM

-WORKDIR ${INVOKEAI_SRC}
+# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
+RUN mkdir -p ${VIRTUAL_ENV} && \
+    mkdir -p ${INVOKEAI_SRC} && \
+    chmod -R a+w /opt
+USER ubuntu
+
+# Install python and create the venv
+RUN uv python install ${PYTHON_VERSION} && \
+    uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
+
+WORKDIR ${INVOKEAI_SRC}
 COPY invokeai ./invokeai
 COPY pyproject.toml ./
@@ -32,25 +45,18 @@ COPY pyproject.toml ./

 # the local working copy can be bind-mounted into the image
 # at path defined by ${INVOKEAI_SRC}
 # NOTE: there are no pytorch builds for arm64 + cuda, only cpu
-# x86_64/CUDA is default
-RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m venv ${VIRTUAL_ENV} &&\
+# x86_64/CUDA is the default
+RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
     if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
     elif [ "$GPU_DRIVER" = "rocm" ]; then \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
     else \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
-    fi &&\
+    fi && \
+    uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."

-    # xformers + triton fails to install on arm64
-    if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
-        pip install $extra_index_url_arg -e ".[xformers]"; \
-    else \
-        pip install $extra_index_url_arg -e "."; \
-    fi

-# #### Build the Web UI ------------------------------------
+#### Build the Web UI ------------------------------------
@@ -66,7 +72,7 @@ RUN npx vite build

 #### Runtime stage ---------------------------------------

-FROM library/ubuntu:23.04 AS runtime
+FROM library/ubuntu:24.04 AS runtime

 ARG DEBIAN_FRONTEND=noninteractive
 ENV PYTHONUNBUFFERED=1
@@ -83,17 +89,16 @@ RUN apt update && apt install -y --no-install-recommends \
     gosu \
     magic-wormhole \
     libglib2.0-0 \
-    libgl1-mesa-glx \
-    python3-venv \
-    python3-pip \
+    libgl1 \
+    libglx-mesa0 \
     build-essential \
     libopencv-dev \
     libstdc++-10-dev &&\
     apt-get clean && apt-get autoclean


 ENV INVOKEAI_SRC=/opt/invokeai
-ENV VIRTUAL_ENV=/opt/venv/invokeai
+ENV VIRTUAL_ENV=/opt/venv
+ENV PYTHON_VERSION=3.11
 ENV INVOKEAI_ROOT=/invokeai
 ENV INVOKEAI_HOST=0.0.0.0
 ENV INVOKEAI_PORT=9090
@@ -101,6 +106,14 @@ ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
 ENV CONTAINER_UID=${CONTAINER_UID:-1000}
 ENV CONTAINER_GID=${CONTAINER_GID:-1000}

+# Install `uv` for package management
+# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
+# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
+COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
+USER ubuntu
+RUN uv python install ${PYTHON_VERSION}
+USER root
+
 # --link requires buldkit w/ dockerfile syntax 1.4
 COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
 COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
@@ -115,7 +128,7 @@ WORKDIR ${INVOKEAI_SRC}

 # build patchmatch
 RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
-RUN python3 -c "from patchmatch import patch_match"
+RUN python -c "from patchmatch import patch_match"

 RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
@@ -16,6 +16,9 @@ set -e -o pipefail

 USER_ID=${CONTAINER_UID:-1000}
 USER=ubuntu
+# if the user does not exist, create it. It is expected to be present on ubuntu >=24.x
+_=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
+# ensure the UID is correct
 usermod -u ${USER_ID} ${USER} 1>/dev/null

 ### Set the $PUBLIC_KEY env var to enable SSH access.
@@ -36,6 +39,8 @@ fi
 mkdir -p "${INVOKEAI_ROOT}"
 chown --recursive ${USER} "${INVOKEAI_ROOT}" || true
 cd "${INVOKEAI_ROOT}"
+export HF_HOME=${HF_HOME:-$INVOKEAI_ROOT/.cache/huggingface}
+export MPLCONFIGDIR=${MPLCONFIGDIR:-$INVOKEAI_ROOT/.matplotlib}

 # Run the CMD as the Container User (not root).
 exec gosu ${USER} "$@"
@@ -50,7 +50,7 @@ Applications are built on top of the invoke framework. They should construct `in

 ### Web UI

-The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. The code is further organized as such:
+The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/invokeai/frontend` and the backend code is found in `/invokeai/app/api_app.py` and `/invokeai/app/api/`. The code is further organized as such:

 | Component | Description |
 | --- | --- |
@@ -62,7 +62,7 @@ The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.t

 ### CLI

-The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`.
+The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/invokeai/frontend/cli`.

 ## Invoke

@@ -70,7 +70,7 @@ The Invoke framework provides the interface to the underlying AI systems and is

 ### Invoker

-The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
+The invoker (`/invokeai/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
 - **invocation services**, which are used by invocations to interact with core functionality.
 - **invoker services**, which are used by the invoker to manage sessions and manage the invocation queue.

@@ -82,12 +82,12 @@ The session graph does not support looping. This is left as an application probl

 ### Invocations

-Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
+Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/invokeai/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.

 ### Services

-Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).
+Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/invokeai/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).

 ## AI Core

-The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`).
+The AI Core is represented by the rest of the code base (i.e. the code outside of `/invokeai/app/`).
@@ -287,8 +287,8 @@ new Invocation ready to be used.

 Once you've created a Node, the next step is to share it with the community! The
 best way to do this is to submit a Pull Request to add the Node to the
-[Community Nodes](nodes/communityNodes) list. If you're not sure how to do that,
-take a look a at our [contributing nodes overview](contributingNodes).
+[Community Nodes](../nodes/communityNodes.md) list. If you're not sure how to do that,
+take a look a at our [contributing nodes overview](../nodes/contributingNodes.md).

 ## Advanced
@@ -9,20 +9,20 @@ model. These are the:
   configuration information. Among other things, the record service
   tracks the type of the model, its provenance, and where it can be
   found on disk.

 * _ModelInstallServiceBase_ A service for installing models to
   disk. It uses `DownloadQueueServiceBase` to download models and
   their metadata, and `ModelRecordServiceBase` to store that
   information. It is also responsible for managing the InvokeAI
   `models` directory and its contents.

 * _DownloadQueueServiceBase_
   A multithreaded downloader responsible
   for downloading models from a remote source to disk. The download
   queue has special methods for downloading repo_id folders from
   Hugging Face, as well as discriminating among model versions in
   Civitai, but can be used for arbitrary content.

 * _ModelLoadServiceBase_
   Responsible for loading a model from disk
   into RAM and VRAM and getting it ready for inference.
@@ -207,9 +207,9 @@ for use in the InvokeAI web server. Its signature is:

 ```
 def open(
-      cls,
-      config: InvokeAIAppConfig,
-      conn: Optional[sqlite3.Connection] = None,
+      cls,
+      config: InvokeAIAppConfig,
+      conn: Optional[sqlite3.Connection] = None,
       lock: Optional[threading.Lock] = None
 ) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
 ```
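
For orientation, a hedged sketch of calling this class method. Only the signature above is taken from this page; the import paths and the config constructor are assumptions, not part of this diff:

```python
# Hedged sketch; import paths are guesses based on module names that appear
# elsewhere on this page, and the config constructor is an assumption.
import sqlite3
import threading

from invokeai.app.services.config import InvokeAIAppConfig  # assumed path
from invokeai.app.services.model_records import ModelRecordServiceBase  # assumed path

config = InvokeAIAppConfig()  # assumed constructor; setup not covered here
conn = sqlite3.connect("databases/invokeai.db", check_same_thread=False)
lock = threading.Lock()

# Share one SQLite connection and lock app-wide, per the signature above.
store = ModelRecordServiceBase.open(config, conn=conn, lock=lock)
```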
@@ -363,7 +363,7 @@ functionality:

 * Registering a model config record for a model already located on the
   local filesystem, without moving it or changing its path.

 * Installing a model alreadiy located on the local filesystem, by
   moving it into the InvokeAI root directory under the
   `models` folder (or wherever config parameter `models_dir`
@@ -371,21 +371,21 @@ functionality:

 * Probing of models to determine their type, base type and other key
   information.

 * Interface with the InvokeAI event bus to provide status updates on
   the download, installation and registration process.

 * Downloading a model from an arbitrary URL and installing it in
   `models_dir`.

 * Special handling for HuggingFace repo_ids to recursively download
   the contents of the repository, paying attention to alternative
   variants such as fp16.

 * Saving tags and other metadata about the model into the invokeai database
   when fetching from a repo that provides that type of information,
   (currently only HuggingFace).

 ### Initializing the installer

 A default installer is created at InvokeAI api startup time and stored
@@ -461,7 +461,7 @@ revision.

 `config` is an optional dict of values that will override the
 autoprobed values for model type, base, scheduler prediction type, and
 so forth. See [Model configuration and
-probing](#Model-configuration-and-probing) for details.
+probing](#model-configuration-and-probing) for details.

 `access_token` is an optional access token for accessing resources
 that need authentication.
@@ -494,7 +494,7 @@ source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', ac

 for source in [source1, source2, source3, source4, source5, source6, source7]:
     install_job = installer.install_model(source)

 source2job = installer.wait_for_installs(timeout=120)
 for source in sources:
     job = source2job[source]
@@ -504,7 +504,7 @@ for source in sources:
         print(f"{source} installed as {model_key}")
     elif job.errored:
         print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")

 ```

 As shown here, the `import_model()` method accepts a variety of
@@ -1,6 +1,6 @@
 # InvokeAI Backend Tests

-We use `pytest` to run the backend python tests. (See [pyproject.toml](/pyproject.toml) for the default `pytest` options.)
+We use `pytest` to run the backend python tests. (See [pyproject.toml](https://github.com/invoke-ai/InvokeAI/blob/main/pyproject.toml) for the default `pytest` options.)

 ## Fast vs. Slow
 All tests are categorized as either 'fast' (no test annotation) or 'slow' (annotated with the `@pytest.mark.slow` decorator).
@@ -33,7 +33,7 @@ pytest tests -m ""

 ## Test Organization

-All backend tests are in the [`tests/`](/tests/) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.
+All backend tests are in the [`tests/`](https://github.com/invoke-ai/InvokeAI/tree/main/tests) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.

 TODO: The above statement is aspirational. A re-organization of legacy tests is required to make it true.
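
To make the fast/slow convention above concrete, a minimal sketch (test names are hypothetical) of how a test opts into the slow category:

```python
import pytest


# Fast by default: no annotation needed.
def test_config_parsing():
    assert 1 + 1 == 2


# Opts into the "slow" category via the decorator named above; selected with
# `pytest tests -m "slow"` and skipped by the default fast run.
@pytest.mark.slow
def test_model_loading_end_to_end():
    ...
```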
@@ -2,7 +2,7 @@

 ## **What do I need to know to help?**

-If you are looking to help with a code contribution, InvokeAI uses several different technologies under the hood: Python (Pydantic, FastAPI, diffusers) and Typescript (React, Redux Toolkit, ChakraUI, Mantine, Konva). Familiarity with StableDiffusion and image generation concepts is helpful, but not essential.
+If you are looking to help with a code contribution, InvokeAI uses several different technologies under the hood: Python (Pydantic, FastAPI, diffusers) and Typescript (React, Redux Toolkit, ChakraUI, Mantine, Konva). Familiarity with StableDiffusion and image generation concepts is helpful, but not essential.

 ## **Get Started**
@@ -12,7 +12,7 @@ To get started, take a look at our [new contributors checklist](newContributorCh
 Once you're setup, for more information, you can review the documentation specific to your area of interest:

 * #### [InvokeAI Architecure](../ARCHITECTURE.md)
-* #### [Frontend Documentation](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web)
+* #### [Frontend Documentation](../frontend/index.md)
 * #### [Node Documentation](../INVOCATIONS.md)
 * #### [Local Development](../LOCAL_DEVELOPMENT.md)

@@ -20,15 +20,15 @@ Once you're setup, for more information, you can review the documentation specif

 If you don't feel ready to make a code contribution yet, no problem! You can also help out in other ways, such as [documentation](documentation.md), [translation](translation.md) or helping support other users and triage issues as they're reported in GitHub.

-There are two paths to making a development contribution:
+There are two paths to making a development contribution:

 1. Choosing an open issue to address. Open issues can be found in the [Issues](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen) section of the InvokeAI repository. These are tagged by the issue type (bug, enhancement, etc.) along with the “good first issues” tag denoting if they are suitable for first time contributors.
-1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
+1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
 2. Opening a new issue or feature to add. **Please make sure you have searched through existing issues before creating new ones.**

 *Regardless of what you choose, please post in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord before you start development in order to confirm that the issue or feature is aligned with the current direction of the project. We value our contributors time and effort and want to ensure that no one’s time is being misspent.*

-## Best Practices:
+## Best Practices:
 * Keep your pull requests small. Smaller pull requests are more likely to be accepted and merged
 * Comments! Commenting your code helps reviewers easily understand your contribution
 * Use Python and Typescript’s typing systems, and consider using an editor with [LSP](https://microsoft.github.io/language-server-protocol/) support to streamline development
@@ -38,7 +38,7 @@ There are two paths to making a development contribution:

 If you need help, you can ask questions in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord.

-For frontend related work, **@psychedelicious** is the best person to reach out to.
+For frontend related work, **@psychedelicious** is the best person to reach out to.

 For backend related work, please reach out to **@blessedcoolant**, **@lstein**, **@StAlKeR7779** or **@psychedelicious**.
@@ -5,7 +5,7 @@ If you're a new contributor to InvokeAI or Open Source Projects, this is the gui
 ## New Contributor Checklist

 - [x] Set up your local development environment & fork of InvokAI by following [the steps outlined here](../dev-environment.md)
-- [x] Set up your local tooling with [this guide](InvokeAI/contributing/LOCAL_DEVELOPMENT/#developing-invokeai-in-vscode). Feel free to skip this step if you already have tooling you're comfortable with.
+- [x] Set up your local tooling with [this guide](../LOCAL_DEVELOPMENT.md). Feel free to skip this step if you already have tooling you're comfortable with.
 - [x] Familiarize yourself with [Git](https://www.atlassian.com/git) & our project structure by reading through the [development documentation](development.md)
 - [x] Join the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord
 - [x] Choose an issue to work on! This can be achieved by asking in the #dev-chat channel, tackling a [good first issue](https://github.com/invoke-ai/InvokeAI/contribute) or finding an item on the [roadmap](https://github.com/orgs/invoke-ai/projects/7). If nothing in any of those places catches your eye, feel free to work on something of interest to you!
@@ -22,15 +22,15 @@ Before starting these steps, ensure you have your local environment [configured
 2. Fork the [InvokeAI](https://github.com/invoke-ai/InvokeAI) repository to your GitHub profile. This means that you will have a copy of the repository under **your-GitHub-username/InvokeAI**.
 3. Clone the repository to your local machine using:

-   ```bash
-   git clone https://github.com/your-GitHub-username/InvokeAI.git
-   ```
+   ```bash
+   git clone https://github.com/your-GitHub-username/InvokeAI.git
+   ```

 If you're unfamiliar with using Git through the commandline, [GitHub Desktop](https://desktop.github.com) is a easy-to-use alternative with a UI. You can do all the same steps listed here, but through the interface. 4. Create a new branch for your fix using:

-   ```bash
-   git checkout -b branch-name-here
-   ```
+   ```bash
+   git checkout -b branch-name-here
+   ```

 5. Make the appropriate changes for the issue you are trying to address or the feature that you want to add.
 6. Add the file contents of the changed files to the "snapshot" git uses to manage the state of the project, also known as the index:
@@ -27,9 +27,9 @@ If you just want to use Invoke, you should use the [installer][installer link].

 5. Activate the venv (you'll need to do this every time you want to run the app):

-   ```sh
-   source .venv/bin/activate
-   ```
+   ```sh
+   source .venv/bin/activate
+   ```

 6. Install the repo as an [editable install][editable install link]:

@@ -37,7 +37,7 @@ If you just want to use Invoke, you should use the [installer][installer link].
    pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
    ```

-Refer to the [manual installation][manual install link]] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
+Refer to the [manual installation][manual install link] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.

 7. Install the frontend dev toolchain:
@@ -34,11 +34,11 @@ Please reach out to @hipsterusername on [Discord](https://discord.gg/ZmtBAhwWhy)

 ## Contributors

-This project is a combined effort of dedicated people from across the world. [Check out the list of all these amazing people](https://invoke-ai.github.io/InvokeAI/other/CONTRIBUTORS/). We thank them for their time, hard work and effort.
+This project is a combined effort of dedicated people from across the world. [Check out the list of all these amazing people](contributors.md). We thank them for their time, hard work and effort.

 ## Code of Conduct

-The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
+The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](../CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.

 By making a contribution to this project, you certify that:
@@ -209,7 +209,7 @@ checkpoint models.

 To solve this, go to the Model Manager tab (the cube), select the
 checkpoint model that's giving you trouble, and press the "Convert"
-button in the upper right of your browser window. This will conver the
+button in the upper right of your browser window. This will convert the
 checkpoint into a diffusers model, after which loading should be
 faster and less memory-intensive.
@@ -97,16 +97,16 @@ Prior to installing PyPatchMatch, you need to take the following steps:
    sudo pacman -S --needed base-devel
    ```

-2. Install `opencv` and `blas`:
+2. Install `opencv`, `blas`, and required dependencies:

    ```sh
-   sudo pacman -S opencv blas
+   sudo pacman -S opencv blas fmt glew vtk hdf5
    ```

    or for CUDA support

    ```sh
-   sudo pacman -S opencv-cuda blas
+   sudo pacman -S opencv-cuda blas fmt glew vtk hdf5
    ```

 3. Fix the naming of the `opencv` package configuration file:
@@ -99,7 +99,6 @@ their descriptions.
 | Scale Latents | Scales latents by a given factor. |
 | Segment Anything Processor | Applies segment anything processing to image |
-| Show Image | Displays a provided image, and passes it forward in the pipeline. |
 | Step Param Easing | Experimental per-step parameter easing for denoising steps |
 | String Primitive Collection | A collection of string primitive values |
 | String Primitive | A string primitive value |
 | Subtract Integers | Subtracts two numbers |
@@ -259,7 +259,7 @@ def select_gpu() -> GpuType:
     [
         f"Detected the [gold1]{OS}-{ARCH}[/] platform",
         "",
-        "See [deep_sky_blue1]https://invoke-ai.github.io/InvokeAI/#system[/] to ensure your system meets the minimum requirements.",
+        "See [deep_sky_blue1]https://invoke-ai.github.io/InvokeAI/installation/requirements/[/] to ensure your system meets the minimum requirements.",
         "",
         "[red3]🠶[/] [b]Your GPU drivers must be correctly installed before using InvokeAI![/] [red3]🠴[/]",
     ]
@@ -68,7 +68,7 @@ do_line_input() {
     printf "2: Open the developer console\n"
     printf "3: Command-line help\n"
     printf "Q: Quit\n\n"
-    printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
+    printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest\n\n"
     read -p "Please enter 1-4, Q: [1] " yn
     choice=${yn:='1'}
     do_choice $choice
@@ -40,6 +40,8 @@ class AppVersion(BaseModel):

     version: str = Field(description="App version")

+    highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
+

 class AppDependencyVersions(BaseModel):
     """App depencency Versions Response"""
@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein
 """FastAPI route for model configuration records."""

+import contextlib
 import io
 import pathlib
 import shutil
@@ -10,6 +11,7 @@ from enum import Enum
 from tempfile import TemporaryDirectory
 from typing import List, Optional, Type

+import huggingface_hub
 from fastapi import Body, Path, Query, Response, UploadFile
 from fastapi.responses import FileResponse, HTMLResponse
 from fastapi.routing import APIRouter
@@ -27,6 +29,7 @@ from invokeai.app.services.model_records import (
     ModelRecordChanges,
     UnknownModelException,
 )
+from invokeai.app.util.suppress_output import SuppressOutput
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
@@ -923,3 +926,51 @@ async def get_stats() -> Optional[CacheStats]:
     """Return performance statistics on the model manager's RAM cache. Will return null if no models have been loaded."""

     return ApiDependencies.invoker.services.model_manager.load.ram_cache.stats
+
+
+class HFTokenStatus(str, Enum):
+    VALID = "valid"
+    INVALID = "invalid"
+    UNKNOWN = "unknown"
+
+
+class HFTokenHelper:
+    @classmethod
+    def get_status(cls) -> HFTokenStatus:
+        try:
+            if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
+                # Valid token!
+                return HFTokenStatus.VALID
+            # No token set
+            return HFTokenStatus.INVALID
+        except Exception:
+            return HFTokenStatus.UNKNOWN
+
+    @classmethod
+    def set_token(cls, token: str) -> HFTokenStatus:
+        with SuppressOutput(), contextlib.suppress(Exception):
+            huggingface_hub.login(token=token, add_to_git_credential=False)
+        return cls.get_status()
+
+
+@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
+async def get_hf_login_status() -> HFTokenStatus:
+    token_status = HFTokenHelper.get_status()
+
+    if token_status is HFTokenStatus.UNKNOWN:
+        ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
+
+    return token_status
+
+
+@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
+async def do_hf_login(
+    token: str = Body(description="Hugging Face token to use for login", embed=True),
+) -> HFTokenStatus:
+    HFTokenHelper.set_token(token)
+    token_status = HFTokenHelper.get_status()
+
+    if token_status is HFTokenStatus.UNKNOWN:
+        ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
+
+    return token_status
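
For context, a minimal client sketch exercising the two endpoints added above. The host/port follow the `INVOKEAI_HOST`/`INVOKEAI_PORT` defaults seen earlier in this diff; the route prefix is an assumption, since the router's mount point is not visible in this excerpt:

```python
import requests

# Assumed base URL: 9090 matches INVOKEAI_PORT above; the model-manager
# route prefix is a guess and not shown in this diff.
BASE = "http://127.0.0.1:9090/api/v2/models"

# GET returns one of "valid" / "invalid" / "unknown", per HFTokenStatus above.
print(requests.get(f"{BASE}/hf_login").json())

# POST logs in; Body(..., embed=True) means the token travels as a JSON field.
print(requests.post(f"{BASE}/hf_login", json={"token": "hf_..."}).json())
```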
@@ -110,7 +110,7 @@ async def cancel_by_batch_ids(
 @session_queue_router.put(
     "/{queue_id}/cancel_by_destination",
     operation_id="cancel_by_destination",
-    responses={200: {"model": CancelByBatchIDsResult}},
+    responses={200: {"model": CancelByDestinationResult}},
 )
 async def cancel_by_destination(
     queue_id: str = Path(description="The queue id to perform this operation on"),
@@ -15,6 +15,11 @@ custom_nodes_readme_path = str(custom_nodes_path / "README.md")
 shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
 shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)

+# set the same permissions as the destination directory, in case our source is read-only,
+# so that the files are user-writable
+for p in custom_nodes_path.glob("**/*"):
+    p.chmod(custom_nodes_path.stat().st_mode)
+
 # Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
 spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
 if spec is None or spec.loader is None:
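
As an aside, the permission copy added above works because `chmod` only applies the permission bits of the mode it is given; a standalone sketch of the same pattern:

```python
import stat
import tempfile
from pathlib import Path

# Standalone sketch: give a copied file the same mode bits as its destination
# directory, so it stays user-writable even if the source tree was read-only.
dest = Path(tempfile.mkdtemp())
f = dest / "example.txt"
f.write_text("hello")
f.chmod(dest.stat().st_mode)  # file-type bits in st_mode are ignored by chmod
print(stat.filemode(f.stat().st_mode))
```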
@@ -4,6 +4,7 @@ from __future__ import annotations

 import inspect
 import re
+import sys
 import warnings
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -62,6 +63,7 @@ class Classification(str, Enum, metaclass=MetaEnum):
     - `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
     - `Deprecated`: The invocation is deprecated and may be removed in a future version.
     - `Internal`: The invocation is not intended for use by end-users. It may be changed or removed at any time, but is exposed for users to play with.
+    - `Special`: The invocation is a special case and does not fit into any of the other classifications.
     """

     Stable = "stable"
@@ -69,6 +71,7 @@ class Classification(str, Enum, metaclass=MetaEnum):
     Prototype = "prototype"
     Deprecated = "deprecated"
     Internal = "internal"
+    Special = "special"


 class UIConfigBase(BaseModel):
@@ -192,12 +195,19 @@
         """Gets a pydantc TypeAdapter for the union of all invocation types."""
         if not cls._typeadapter or cls._typeadapter_needs_update:
             AnyInvocation = TypeAliasType(
-                "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
+                "AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
             )
             cls._typeadapter = TypeAdapter(AnyInvocation)
             cls._typeadapter_needs_update = False
         return cls._typeadapter

+    @classmethod
+    def invalidate_typeadapter(cls) -> None:
+        """Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
+        denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
+        the updated allowlist and denylist."""
+        cls._typeadapter_needs_update = True
+
     @classmethod
     def get_invocations(cls) -> Iterable[BaseInvocation]:
         """Gets all invocations, respecting the allowlist and denylist."""
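
A small usage sketch of the new invalidation hook. The accessor name `get_typeadapter` is an assumption (its `def` line falls outside the hunk above); the allowlist mutation is left abstract, per the docstring:

```python
# Hedged sketch; `get_typeadapter` is assumed to be the method whose body is
# shown above, and the allowlist/denylist change is left abstract.
from invokeai.app.invocations.baseinvocation import BaseInvocation

# ... suppose the application's invocation allowlist or denylist changes here ...

BaseInvocation.invalidate_typeadapter()     # mark the cached adapter stale
adapter = BaseInvocation.get_typeadapter()  # lazily rebuilt on this access
```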
@@ -479,6 +489,26 @@ def invocation(
             title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
         )

+        # Validate the `invoke()` method is implemented
+        if "invoke" in cls.__abstractmethods__:
+            raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
+
+        # And validate that `invoke()` returns a subclass of `BaseInvocationOutput`
+        invoke_return_annotation = signature(cls.invoke).return_annotation
+
+        try:
+            # TODO(psyche): If `invoke()` is not defined, `return_annotation` ends up as the string "BaseInvocationOutput"
+            # instead of the class `BaseInvocationOutput`. This may be a pydantic bug: https://github.com/pydantic/pydantic/issues/7978
+            if isinstance(invoke_return_annotation, str):
+                invoke_return_annotation = getattr(sys.modules[cls.__module__], invoke_return_annotation)
+
+            assert invoke_return_annotation is not BaseInvocationOutput
+            assert issubclass(invoke_return_annotation, BaseInvocationOutput)
+        except Exception:
+            raise ValueError(
+                f'Invocation "{invocation_type}" must have a return annotation of a subclass of BaseInvocationOutput (got "{invoke_return_annotation}")'
+            )
+
         docstring = cls.__doc__
         cls = create_model(
             cls.__qualname__,
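
To make the two new checks concrete, a minimal sketch of an invocation that passes both: it implements `invoke()` and annotates the return type with a concrete `BaseInvocationOutput` subclass. `IntegerOutput` is assumed to live in the primitives module imported elsewhere in this diff; field details are illustrative:

```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.primitives import IntegerOutput  # assumed export
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("add_ints_example", title="Add Integers (example)", category="math", version="1.0.0")
class AddIntsExampleInvocation(BaseInvocation):
    """Adds two integers. Sketch only: invoke() is implemented and its return
    annotation is a concrete BaseInvocationOutput subclass, satisfying both
    validations in the decorator above."""

    a: int = InputField(default=0, description="First addend")
    b: int = InputField(default=0, description="Second addend")

    def invoke(self, context: InvocationContext) -> IntegerOutput:
        return IntegerOutput(value=self.a + self.b)
```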
@@ -1,98 +1,120 @@
-from typing import Any, Union
+from typing import Optional, Union

 import numpy as np
-import numpy.typing as npt
 import torch
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import resize as tv_resize

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, LatentsField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.util.devices import TorchDevice


+def slerp(
+    t: Union[float, np.ndarray],
+    v0: Union[torch.Tensor, np.ndarray],
+    v1: Union[torch.Tensor, np.ndarray],
+    device: torch.device,
+    DOT_THRESHOLD: float = 0.9995,
+):
+    """
+    Spherical linear interpolation
+    Args:
+        t (float/np.ndarray): Float value between 0.0 and 1.0
+        v0 (np.ndarray): Starting vector
+        v1 (np.ndarray): Final vector
+        DOT_THRESHOLD (float): Threshold for considering the two vectors as
+            colineal. Not recommended to alter this.
+    Returns:
+        v2 (np.ndarray): Interpolation vector between v0 and v1
+    """
+    inputs_are_torch = False
+    if not isinstance(v0, np.ndarray):
+        inputs_are_torch = True
+        v0 = v0.detach().cpu().numpy()
+    if not isinstance(v1, np.ndarray):
+        inputs_are_torch = True
+        v1 = v1.detach().cpu().numpy()
+
+    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+    if np.abs(dot) > DOT_THRESHOLD:
+        v2 = (1 - t) * v0 + t * v1
+    else:
+        theta_0 = np.arccos(dot)
+        sin_theta_0 = np.sin(theta_0)
+        theta_t = theta_0 * t
+        sin_theta_t = np.sin(theta_t)
+        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+        s1 = sin_theta_t / sin_theta_0
+        v2 = s0 * v0 + s1 * v1
+
+    if inputs_are_torch:
+        v2 = torch.from_numpy(v2).to(device)
+
+    return v2
+
+
 @invocation(
     "lblend",
     title="Blend Latents",
-    tags=["latents", "blend"],
+    tags=["latents", "blend", "mask"],
     category="latents",
-    version="1.0.3",
+    version="1.1.0",
 )
 class BlendLatentsInvocation(BaseInvocation):
-    """Blend two latents using a given alpha. Latents must have same size."""
+    """Blend two latents using a given alpha. If a mask is provided, the second latents will be masked before blending.
+    Latents must have same size. Masking functionality added by @dwringer."""

-    latents_a: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    latents_b: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
+    latents_a: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
+    latents_b: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
+    mask: Optional[ImageField] = InputField(default=None, description="Mask for blending in latents B")
+    alpha: float = InputField(ge=0, default=0.5, description=FieldDescriptions.blend_alpha)
+
+    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
+        if mask_image.mode != "L":
+            mask_image = mask_image.convert("L")
+        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        if mask_tensor.dim() == 3:
+            mask_tensor = mask_tensor.unsqueeze(0)
+        return mask_tensor
+
+    def replace_tensor_from_masked_tensor(
+        self, tensor: torch.Tensor, other_tensor: torch.Tensor, mask_tensor: torch.Tensor
+    ):
+        output = tensor.clone()
+        mask_tensor = mask_tensor.expand(output.shape)
+        if output.dtype != torch.float16:
+            output = torch.add(output, mask_tensor * torch.sub(other_tensor, tensor))
+        else:
+            output = torch.add(output, mask_tensor.half() * torch.sub(other_tensor, tensor))
+        return output

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents_a = context.tensors.load(self.latents_a.latents_name)
         latents_b = context.tensors.load(self.latents_b.latents_name)
+        if self.mask is None:
+            mask_tensor = torch.zeros(latents_a.shape[-2:])
+        else:
+            mask_tensor = self.prep_mask_tensor(context.images.get_pil(self.mask.image_name))
+        mask_tensor = tv_resize(mask_tensor, latents_a.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
+
+        latents_b = self.replace_tensor_from_masked_tensor(latents_b, latents_a, mask_tensor)

         if latents_a.shape != latents_b.shape:
-            raise Exception("Latents to blend must be the same size.")
+            raise ValueError("Latents to blend must be the same size.")

         device = TorchDevice.choose_torch_device()

-        def slerp(
-            t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
-            v0: Union[torch.Tensor, npt.NDArray[Any]],
-            v1: Union[torch.Tensor, npt.NDArray[Any]],
-            DOT_THRESHOLD: float = 0.9995,
-        ) -> Union[torch.Tensor, npt.NDArray[Any]]:
-            """
-            Spherical linear interpolation
-            Args:
-                t (float/np.ndarray): Float value between 0.0 and 1.0
-                v0 (np.ndarray): Starting vector
-                v1 (np.ndarray): Final vector
-                DOT_THRESHOLD (float): Threshold for considering the two vectors as
-                    colineal. Not recommended to alter this.
-            Returns:
-                v2 (np.ndarray): Interpolation vector between v0 and v1
-            """
-            inputs_are_torch = False
-            if not isinstance(v0, np.ndarray):
-                inputs_are_torch = True
-                v0 = v0.detach().cpu().numpy()
-            if not isinstance(v1, np.ndarray):
-                inputs_are_torch = True
-                v1 = v1.detach().cpu().numpy()
-
-            dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
-            if np.abs(dot) > DOT_THRESHOLD:
-                v2 = (1 - t) * v0 + t * v1
-            else:
-                theta_0 = np.arccos(dot)
-                sin_theta_0 = np.sin(theta_0)
-                theta_t = theta_0 * t
-                sin_theta_t = np.sin(theta_t)
-                s0 = np.sin(theta_0 - theta_t) / sin_theta_0
-                s1 = sin_theta_t / sin_theta_0
-                v2 = s0 * v0 + s1 * v1
-
-            if inputs_are_torch:
-                v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
-                return v2_torch
-            else:
-                assert isinstance(v2, np.ndarray)
-                return v2
-
-        # blend
-        bl = slerp(self.alpha, latents_a, latents_b)
-        assert isinstance(bl, torch.Tensor)
-        blended_latents: torch.Tensor = bl  # for type checking convenience
+        blended_latents = slerp(self.alpha, latents_a, latents_b, device)

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")

-        TorchDevice.empty_cache()
+        torch.cuda.empty_cache()

         name = context.tensors.save(tensor=blended_latents)
-        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
+        return LatentsOutput.build(latents_name=name, latents=blended_latents)
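
As an aside on the masking step above: `replace_tensor_from_masked_tensor` is a per-element select, `out = a + m * (b - a)`, so mask 0 keeps `a` and mask 1 takes `b`. A standalone sketch of the same identity:

```python
import torch

# Standalone sketch of the select identity used by the masked blend above:
# out = a + m * (b - a)  ==  torch.where(m.bool(), b, a) for a binary mask m.
a = torch.zeros(1, 4, 8, 8)
b = torch.ones(1, 4, 8, 8)
m = torch.zeros(8, 8)
m[:, 4:] = 1.0  # right half selects b

out = a + m.expand(a.shape) * (b - a)
assert torch.equal(out, torch.where(m.expand(a.shape).bool(), b, a))
```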
@@ -82,10 +82,11 @@ class CompelInvocation(BaseInvocation):
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 model=text_encoder,
                 patches=_lora_loader(),
                 prefix="lora_te_",
+                dtype=TorchDevice.choose_torch_dtype(),
                 cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -95,6 +96,7 @@ class CompelInvocation(BaseInvocation):
                 ti_manager,
             ),
         ):
             context.util.signal_progress("Building conditioning")
+            assert isinstance(text_encoder, CLIPTextModel)
             assert isinstance(tokenizer, CLIPTokenizer)
             compel = Compel(
@@ -178,10 +180,11 @@ class SDXLPromptInvocationBase:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
-            LoRAPatcher.apply_lora_patches(
+            LoRAPatcher.apply_smart_lora_patches(
                 text_encoder,
                 patches=_lora_loader(),
                 prefix=lora_prefix,
+                dtype=TorchDevice.choose_torch_dtype(),
                 cached_weights=cached_weights,
            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -191,6 +194,7 @@ class SDXLPromptInvocationBase:
                 ti_manager,
             ),
         ):
             context.util.signal_progress("Building conditioning")
+            assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
             assert isinstance(tokenizer, CLIPTokenizer)
invokeai/app/invocations/composition-nodes.py (new file, 1563 lines; file diff suppressed because it is too large)
@@ -65,6 +65,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

masked_latents_name = context.tensors.save(tensor=masked_latents)

@@ -131,6 +131,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)

@@ -13,6 +13,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
from diffusers.schedulers.scheduling_tcd import TCDScheduler
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
from PIL import Image
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
@@ -510,6 +511,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
ext_manager: ExtensionsManager,
bgr_mode: bool = False,
) -> None:
if t2i_adapters is None:
return
@@ -519,6 +521,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapters = [t2i_adapters]

for t2i_adapter_field in t2i_adapters:
image = context.images.get_pil(t2i_adapter_field.image.image_name)
if bgr_mode:  # SDXL t2i trained on cv2's BGR outputs, but PIL won't convert straight to BGR
r, g, b = image.split()
image = Image.merge("RGB", (b, g, r))
ext_manager.add_extension(
T2IAdapterExt(
node_context=context,
@@ -616,13 +622,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
-image = context.images.get_pil(t2i_adapter_field.image.image_name)
+image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")

# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4

# SDXL adapters are trained on cv2's BGR outputs
r, g, b = image.split()
image = Image.merge("RGB", (b, g, r))
else:
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")

@@ -630,29 +640,39 @@ class DenoiseLatentsInvocation(BaseInvocation):
with t2i_adapter_loaded_model as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor

# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor

# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).

# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
-width=t2i_input_width,
-height=t2i_input_height,
+width=control_width_resize,
+height=control_height_resize,
num_channels=t2i_adapter_model.config["in_channels"],  # mypy treats this as a FrozenDict
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)

# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
# We crop the image to this size so that the positions match the input image on non-standard resolutions
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
if t2i_image.shape[2] > t2i_input_height or t2i_image.shape[3] > t2i_input_width:
t2i_image = t2i_image[
:, :, : min(t2i_image.shape[2], t2i_input_height), : min(t2i_image.shape[3], t2i_input_width)
]

adapter_state = t2i_adapter_model(t2i_image)

if do_classifier_free_guidance:
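To make the resize-then-crop arithmetic above concrete, here is an illustrative worked example; the latent shape and the adapter's total_downscale_factor of 64 are assumptions for an SD1.x run, not values taken from this diff:

LATENT_SCALE_FACTOR = 8  # latent -> pixel scale for SD UNets

# Assume a 480x640 generation, i.e. a 60x80 latent image.
latent_height, latent_width = 60, 80

# Step 1 (new code): resize the control image to the full pixel resolution.
control_height_resize = latent_height * LATENT_SCALE_FACTOR  # 480
control_width_resize = latent_width * LATENT_SCALE_FACTOR    # 640

# Step 2 (new code): crop so that, after the adapter downscales by
# total_downscale_factor, the feature map matches the latent image after
# max_unet_downscale. With max_unet_downscale = 8 (SD1.x) and an assumed
# total_downscale_factor = 64:
t2i_input_height = latent_height // 8 * 64  # 448
t2i_input_width = latent_width // 8 * 64    # 640

# The 480-px-tall resized image is cropped to 448 px, keeping the adapter
# features aligned with the latents at this non-multiple-of-64 resolution.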
@@ -900,7 +920,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
-self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
+bgr_mode = self.unet.unet.base == BaseModelType.StableDiffusionXL
+self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager, bgr_mode)

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@@ -982,10 +1003,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
-LoRAPatcher.apply_lora_patches(
+LoRAPatcher.apply_smart_lora_patches(
model=unet,
patches=_lora_loader(),
prefix="lora_unet_",
dtype=unet.dtype,
cached_weights=cached_weights,
),
):

@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -52,7 +53,10 @@ class UIType(str, Enum, metaclass=MetaEnum):
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
StructuralLoRAModel = "StructuralLoRAModelField"
# endregion

# region Misc Field Types
@@ -131,15 +135,19 @@ class FieldDescriptions:
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
structural_lora_model = "Structural LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -244,6 +252,17 @@ class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)


class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

conditioning_name: str = Field(description="The name of conditioning tensor")


class ConditioningField(BaseModel):

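Given the mask convention just introduced (excluded regions False, included regions True), a regional mask could be built like this; the (1, h, w) layout is an assumption for illustration only:

import torch

h, w = 1024, 1024
mask = torch.zeros(1, h, w, dtype=torch.bool)
mask[:, :, : w // 2] = True  # this conditioning applies only to the left half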
@@ -1,5 +1,5 @@
from contextlib import ExitStack
-from typing import Callable, Iterator, Optional, Tuple
+from typing import Callable, Iterator, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
@@ -8,6 +8,8 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from invokeai.backend.flux.modules.autoencoder import AutoEncoder

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@@ -22,7 +24,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.model import TransformerField, VAEField
+from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -30,6 +32,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
@@ -42,6 +45,9 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
@@ -56,7 +62,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
-version="3.2.0",
+version="3.2.2",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -81,15 +87,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
-positive_text_conditioning: FluxConditioningField = InputField(
+positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
-negative_text_conditioning: FluxConditioningField | None = InputField(
+negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
@@ -138,36 +145,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
return t5_embeddings, clip_embeddings

def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16

# Load the conditioning data.
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings: torch.Tensor | None = None
neg_clip_embeddings: torch.Tensor | None = None
if self.negative_text_conditioning is not None:
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)

# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
@@ -182,15 +165,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
seed=self.seed,
)
b, _c, latent_h, latent_w = noise.shape
packed_h = latent_h // 2
packed_w = latent_w // 2

# Load the conditioning data.
pos_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.positive_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
neg_text_conditionings: list[FluxTextConditioning] | None = None
if self.negative_text_conditioning is not None:
neg_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.negative_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
pos_text_conditionings, img_seq_len=packed_h * packed_w
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
if neg_text_conditionings
else None
)

transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path

# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
-image_seq_len=image_seq_len,
+image_seq_len=packed_h * packed_w,
shift=not is_schnell,
)

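A quick sanity check on the packed-dimension bookkeeping above: FLUX packs latents into 2x2 patches, so packed_h * packed_w coincides with the old image_seq_len formula. A standalone sketch using einops; the latent shape is an assumption, and this emulates rather than calls the repo's pack():

import torch
from einops import rearrange

b, c, latent_h, latent_w = 1, 16, 128, 128  # illustrative FLUX latent shape
noise = torch.randn(b, c, latent_h, latent_w)

packed_h, packed_w = latent_h // 2, latent_w // 2

# Fold each 2x2 spatial patch into the channel dimension.
packed = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

assert packed.shape[1] == packed_h * packed_w        # new formula
assert packed.shape[1] == latent_h * latent_w // 4   # old image_seq_len formula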
@@ -207,9 +220,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"to be poor. Consider using a FLUX dev model instead."
)

-# Noise the orig_latents by the appropriate amount for the first timestep.
-t_0 = timesteps[0]
-x = t_0 * noise + (1.0 - t_0) * init_latents
+if self.add_noise:
+# Noise the orig_latents by the appropriate amount for the first timestep.
+t_0 = timesteps[0]
+x = t_0 * noise + (1.0 - t_0) * init_latents
+else:
+x = init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
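The add_noise branch is the usual rectified-flow noising: at the first timestep t_0, the starting sample linearly interpolates between pure noise and the encoded image. A minimal illustration; shapes and values are assumptions:

import torch

noise = torch.randn(1, 16, 128, 128)
init_latents = torch.randn(1, 16, 128, 128)  # stand-in for VAE-encoded image latents

t_0 = 0.75  # first timestep of the truncated schedule, e.g. when denoising_start = 0.25
x = t_0 * noise + (1.0 - t_0) * init_latents  # t_0 = 1.0 -> pure noise; t_0 = 0.0 -> the image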
@@ -224,28 +240,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

inpaint_mask = self._prep_inpaint_mask(context, x)

b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
pos_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_txt_ids: torch.Tensor | None = None
if neg_t5_embeddings is not None:
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)

# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)

-# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
-assert image_seq_len == x.shape[1]
+# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
+# packed_w correctly.
+assert packed_h * packed_w == x.shape[1]

# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
@@ -283,6 +288,16 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
device=x.device,
)
img_cond = None
if struct_lora := self.transformer.structural_lora:
# What should we do when we have multiple of these?
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a structural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)

# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
@@ -295,10 +310,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if config.format in [ModelFormat.Checkpoint]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
-LoRAPatcher.apply_lora_patches(
+LoRAPatcher.apply_smart_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
@@ -310,7 +326,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference
# than directly patching the weights, but is agnostic to the quantization format.
exit_stack.enter_context(
-LoRAPatcher.apply_lora_sidecar_patches(
+LoRAPatcher.apply_lora_wrapper_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
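The comment above spells out the trade-off behind the wrapper patches: quantized base weights cannot be patched in place, so the LoRA delta runs as a sidecar at inference time. A conceptual sketch of such a wrapper follows; this is an illustration of the technique, not the repo's implementation:

import torch

class LoRASidecarLinear(torch.nn.Module):
    """Wrap a frozen (possibly quantized) linear layer and add the LoRA
    delta on the fly instead of fusing it into the weights."""

    def __init__(self, base: torch.nn.Module, down: torch.Tensor, up: torch.Tensor, scale: float = 1.0):
        super().__init__()
        self.base = base    # quantization format of the base layer does not matter here
        self.down = down    # (in_features, rank)
        self.up = up        # (rank, out_features)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base(x) runs in whatever format the base layer uses; the LoRA
        # delta is computed separately, which is the extra cost the comment notes.
        return self.base(x) + (x @ self.down @ self.up) * self.scale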
@@ -334,12 +350,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
model=transformer,
img=x,
img_ids=img_ids,
-txt=pos_t5_embeddings,
-txt_ids=pos_txt_ids,
-vec=pos_clip_embeddings,
-neg_txt=neg_t5_embeddings,
-neg_txt_ids=neg_txt_ids,
-neg_vec=neg_clip_embeddings,
+pos_regional_prompting_extension=pos_regional_prompting_extension,
+neg_regional_prompting_extension=neg_regional_prompting_extension,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
@@ -348,11 +360,49 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
+img_cond=img_cond
)

x = unpack(x.float(), self.height, self.width)
return x

def _load_text_conditioning(
self,
context: InvocationContext,
cond_field: FluxConditioningField | list[FluxConditioningField],
packed_height: int,
packed_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[FluxTextConditioning]:
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
# Normalize to a list of FluxConditioningFields.
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field

text_conditionings: list[FluxTextConditioning] = []
for cond_field in cond_list:
# Load the text embeddings.
cond_data = context.conditioning.load(cond_field.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds

# Load the mask, if provided.
mask: Optional[torch.Tensor] = None
if cond_field.mask is not None:
mask = context.tensors.load(cond_field.mask.tensor_name)
mask = mask.to(device=device)
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, packed_height, packed_width, dtype, device
)

text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))

return text_conditionings

@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
@@ -648,7 +698,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions

def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
-for lora in self.transformer.loras:
+loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
+if self.transformer.structural_lora:
+loras.append(self.transformer.structural_lora)
+for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)

invokeai/app/invocations/flux_model_loader.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from typing import Literal

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    SubModelType,
)


@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
    """Flux base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    max_seq_len: Literal[256, 512] = OutputField(
        description="The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
        title="Max Seq Length",
    )


@invocation(
    "flux_model_loader",
    title="Flux Main Model",
    tags=["model", "flux"],
    category="model",
    version="1.0.4",
    classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
    """Loads a flux base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        ui_type=UIType.FluxMainModel,
        input=Input.Direct,
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
    )

    clip_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPEmbedModel,
        input=Input.Direct,
        title="CLIP Embed",
    )

    vae_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
    )

    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
            if not context.models.exists(key):
                raise ValueError(f"Unknown model: {key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

        transformer_config = context.models.get_config(transformer)
        assert isinstance(transformer_config, CheckpointConfigBase)

        return FluxModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=max_seq_lengths[transformer_config.config_path],
        )
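The loader above leans on Pydantic v2's model_copy(update=...) to mint one identifier per submodel from a single base identifier. A standalone illustration of the idiom; the ModelId class here is hypothetical:

from pydantic import BaseModel

class ModelId(BaseModel):
    key: str
    submodel_type: str | None = None

base = ModelId(key="flux-dev")
transformer = base.model_copy(update={"submodel_type": "transformer"})
vae = base.model_copy(update={"submodel_type": "vae"})
assert base.submodel_type is None  # the original identifier is untouched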
invokeai/app/invocations/flux_structural_lora_loader.py (new file, 70 lines)
@@ -0,0 +1,70 @@
from typing import Optional, Literal

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
    """Flux Structural LoRA Loader Output"""

    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
    )


@invocation(
    "flux_structural_lora_loader",
    title="Flux Structural LoRA",
    tags=["lora", "model", "flux"],
    category="model",
    version="1.1.0",
    classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
    """Apply a structural LoRA model to a FLUX transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
    )
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="FLUX Transformer",
    )
    image: ImageField = InputField(
        description="The image to encode.",
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)

    def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown lora: {lora_key}!")

        # Check for existing LoRAs with the same key.
        if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
            raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')

        output = FluxStructuralLoRALoaderOutput()

        # Attach LoRA layers to the models.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)
            output.transformer.structural_lora = StructuralLoRAField(
                lora=self.lora,
                img=self.image,
                weight=self.weight,
            )

        return output
@@ -1,11 +1,18 @@
from contextlib import ExitStack
-from typing import Iterator, Literal, Tuple
+from typing import Iterator, Literal, Optional, Tuple

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
+from invokeai.app.invocations.fields import (
+FieldDescriptions,
+FluxConditioningField,
+Input,
+InputField,
+TensorField,
+UIComponent,
+)
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -15,6 +22,7 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
@@ -22,7 +30,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
-version="1.1.0",
+version="1.1.1",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
@@ -41,7 +49,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
-prompt: str = InputField(description="Text prompt to encode.")
+prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
@@ -54,7 +65,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
)

conditioning_name = context.conditioning.save(conditioning_data)
-return FluxConditioningOutput.build(conditioning_name)
+return FluxConditioningOutput(
+conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+)

def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
@@ -71,6 +84,7 @@ class FluxTextEncoderInvocation(BaseInvocation):

t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)

assert isinstance(prompt_embeds, torch.Tensor)
@@ -98,10 +112,11 @@ class FluxTextEncoderInvocation(BaseInvocation):
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
-LoRAPatcher.apply_lora_patches(
+LoRAPatcher.apply_smart_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
prefix=FLUX_LORA_CLIP_PREFIX,
dtype=TorchDevice.choose_torch_dtype(),
cached_weights=cached_weights,
)
)
@@ -111,6 +126,7 @@ class FluxTextEncoderInvocation(BaseInvocation):

clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)

context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)

assert isinstance(pooled_prompt_embeds, torch.Tensor)

@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
-latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
+vae_dtype = next(iter(vae.parameters())).dtype
+latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
img = vae.decode(latents)

img = img.clamp(-1, 1)
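Both VAE hunks in this area replace a globally chosen dtype with the dtype the loaded model actually carries; the idiom generalizes to any torch module. A standalone sketch:

import torch

model = torch.nn.Linear(4, 4).to(torch.bfloat16)
x = torch.randn(2, 4)  # float32 by default

model_dtype = next(iter(model.parameters())).dtype
x = x.to(dtype=model_dtype)  # cast the input to match the weights
y = model(x)
assert y.dtype == torch.bfloat16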
@@ -53,6 +54,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
image = self._vae_decode(vae_info=vae_info, latents=latents)

TorchDevice.empty_cache()

@@ -44,9 +44,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
-image_tensor = image_tensor.to(
-device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
-)
+vae_dtype = next(iter(vae.parameters())).dtype
+image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents

@@ -60,6 +59,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")

invokeai/app/invocations/image_panels.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from pydantic import ValidationInfo, field_validator

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("image_panel_coordinate_output")
class ImagePanelCoordinateOutput(BaseInvocationOutput):
    x_left: int = OutputField(description="The left x-coordinate of the panel.")
    y_top: int = OutputField(description="The top y-coordinate of the panel.")
    width: int = OutputField(description="The width of the panel.")
    height: int = OutputField(description="The height of the panel.")


@invocation(
    "image_panel_layout",
    title="Image Panel Layout",
    tags=["image", "panel", "layout"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class ImagePanelLayoutInvocation(BaseInvocation):
    """Get the coordinates of a single panel in a grid. (If the full image shape cannot be divided evenly into panels,
    then the grid may not cover the entire image.)
    """

    width: int = InputField(description="The width of the entire grid.")
    height: int = InputField(description="The height of the entire grid.")
    num_cols: int = InputField(ge=1, default=1, description="The number of columns in the grid.")
    num_rows: int = InputField(ge=1, default=1, description="The number of rows in the grid.")
    panel_col_idx: int = InputField(ge=0, default=0, description="The column index of the panel to be processed.")
    panel_row_idx: int = InputField(ge=0, default=0, description="The row index of the panel to be processed.")

    @field_validator("panel_col_idx")
    def validate_panel_col_idx(cls, v: int, info: ValidationInfo) -> int:
        if v < 0 or v >= info.data["num_cols"]:
            raise ValueError(f"panel_col_idx must be between 0 and {info.data['num_cols'] - 1}")
        return v

    @field_validator("panel_row_idx")
    def validate_panel_row_idx(cls, v: int, info: ValidationInfo) -> int:
        if v < 0 or v >= info.data["num_rows"]:
            raise ValueError(f"panel_row_idx must be between 0 and {info.data['num_rows'] - 1}")
        return v

    def invoke(self, context: InvocationContext) -> ImagePanelCoordinateOutput:
        x_left = self.panel_col_idx * (self.width // self.num_cols)
        y_top = self.panel_row_idx * (self.height // self.num_rows)
        width = self.width // self.num_cols
        height = self.height // self.num_rows
        return ImagePanelCoordinateOutput(x_left=x_left, y_top=y_top, width=width, height=height)
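A concrete instance of the docstring's caveat about uneven division (the numbers are illustrative): with a 1024-wide grid and num_cols=3, each panel is 1024 // 3 = 341 px wide, so the grid spans 1023 px and one pixel column is left uncovered:

width, num_cols = 1024, 3
panel_width = width // num_cols                       # 341
x_lefts = [i * panel_width for i in range(num_cols)]  # [0, 341, 682]
assert x_lefts[-1] + panel_width == 1023              # < 1024: 1 px uncovered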
@@ -117,6 +117,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)

@@ -60,6 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:

@@ -147,6 +147,10 @@ GENERATION_MODES = Literal[
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"sd3_txt2img",
"sd3_img2img",
"sd3_inpaint",
"sd3_outpaint",
]

@@ -1,5 +1,5 @@
import copy
-from typing import List, Literal, Optional
+from typing import List, Optional, Literal

from pydantic import BaseModel, Field

@@ -10,14 +10,12 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
@@ -67,11 +65,6 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -81,6 +74,13 @@ class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')

class StructuralLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")

class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)

@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
@@ -139,78 +139,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)


@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""

transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)


@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""

model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)

t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)

clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)

vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)

def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")

transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})

tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})

transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)

return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)


@invocation(
"main_model_loader",
title="Main Model",

@@ -1,43 +1,4 @@
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from easing_functions import (
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseInOut,
|
||||
CubicEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseOut,
|
||||
LinearInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseInOut,
|
||||
QuadEaseOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseOut,
|
||||
SineEaseIn,
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
@@ -65,191 +26,3 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@invocation(
|
||||
"step_param_easing",
|
||||
title="Step Param Easing",
|
||||
tags=["step", "easing"],
|
||||
category="step",
|
||||
version="1.0.2",
|
||||
)
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||
    end_value: float = InputField(default=1.0, description="easing ending value")
    start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
    end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
    # if None, then start_value is used prior to easing start
    pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
    # if None, then end_value is used after easing end
    post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
    mirror: bool = InputField(default=False, description="include mirror of easing function")
    # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
    # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
    show_easing_plot: bool = InputField(default=False, description="show easing plot")

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        log_diagnostics = False
        # convert from start_step_percent to the nearest step
        # start_step = int(np.floor(self.num_steps * self.start_step_percent))
        start_step = int(np.round(self.num_steps * self.start_step_percent))
        # convert from end_step_percent to the nearest step
        # end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
        end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))

        # end_step = int(np.ceil(self.num_steps * self.end_step_percent))
        num_easing_steps = end_step - start_step + 1

        # num_presteps = max(start_step - 1, 0)
        num_presteps = start_step
        num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
        prelist = list(num_presteps * [self.pre_start_value])
        postlist = list(num_poststeps * [self.post_end_value])

        if log_diagnostics:
            context.logger.debug("start_step: " + str(start_step))
            context.logger.debug("end_step: " + str(end_step))
            context.logger.debug("num_easing_steps: " + str(num_easing_steps))
            context.logger.debug("num_presteps: " + str(num_presteps))
            context.logger.debug("num_poststeps: " + str(num_poststeps))
            context.logger.debug("prelist size: " + str(len(prelist)))
            context.logger.debug("postlist size: " + str(len(postlist)))
            context.logger.debug("prelist: " + str(prelist))
            context.logger.debug("postlist: " + str(postlist))

        easing_class = EASING_FUNCTIONS_MAP[self.easing]
        if log_diagnostics:
            context.logger.debug("easing class: " + str(easing_class))
        easing_list = []
        if self.mirror:  # "expected" mirroring
            # if the number of steps is even, squeeze the easing duration down to (number_of_steps) / 2
            # and append a reversed copy of the list
            # if the number of steps is odd, squeeze the easing duration down to ceil(number_of_steps / 2)
            # and append a reversed copy of list[1:end-1]
            # if even, then number_of_steps / 2 == ceil(number_of_steps / 2), so ceil can be used in both cases

            base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
            if log_diagnostics:
                context.logger.debug("base easing duration: " + str(base_easing_duration))
            even_num_steps = num_easing_steps % 2 == 0  # even number of steps
            easing_function = easing_class(
                start=self.start_value,
                end=self.end_value,
                duration=base_easing_duration - 1,
            )
            base_easing_vals = []
            for step_index in range(base_easing_duration):
                easing_val = easing_function.ease(step_index)
                base_easing_vals.append(easing_val)
                if log_diagnostics:
                    context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
            if even_num_steps:
                mirror_easing_vals = list(reversed(base_easing_vals))
            else:
                mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
            if log_diagnostics:
                context.logger.debug("base easing vals: " + str(base_easing_vals))
                context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
            easing_list = base_easing_vals + mirror_easing_vals

        # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
        # elif self.alt_mirror:  # function mirroring (unintuitive behavior, at least to me)
        #     # half_ease_duration = round(num_easing_steps - 1 / 2)
        #     half_ease_duration = round((num_easing_steps - 1) / 2)
        #     easing_function = easing_class(
        #         start=self.start_value,
        #         end=self.end_value,
        #         duration=half_ease_duration,
        #     )
        #
        #     mirror_function = easing_class(
        #         start=self.end_value,
        #         end=self.start_value,
        #         duration=half_ease_duration,
        #     )
        #     for step_index in range(num_easing_steps):
        #         if step_index <= half_ease_duration:
        #             step_val = easing_function.ease(step_index)
        #         else:
        #             step_val = mirror_function.ease(step_index - half_ease_duration)
        #         easing_list.append(step_val)
        #         if log_diagnostics: logger.debug(step_index, step_val)

        else:  # no mirroring (default)
            easing_function = easing_class(
                start=self.start_value,
                end=self.end_value,
                duration=num_easing_steps - 1,
            )
            for step_index in range(num_easing_steps):
                step_val = easing_function.ease(step_index)
                easing_list.append(step_val)
                if log_diagnostics:
                    context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))

        if log_diagnostics:
            context.logger.debug("prelist size: " + str(len(prelist)))
            context.logger.debug("easing_list size: " + str(len(easing_list)))
            context.logger.debug("postlist size: " + str(len(postlist)))

        param_list = prelist + easing_list + postlist

        if self.show_easing_plot:
            plt.figure()
            plt.xlabel("Step")
            plt.ylabel("Param Value")
            plt.title("Per-Step Values Based On Easing: " + self.easing)
            plt.bar(range(len(param_list)), param_list)
            # plt.plot(param_list)
            ax = plt.gca()
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            buf = io.BytesIO()
            plt.savefig(buf, format="png")
            buf.seek(0)
            im = PIL.Image.open(buf)
            im.show()
            buf.close()

        # output array of size steps; each entry list[i] is the param value for step i
        return FloatCollectionOutput(collection=param_list)
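# --- Editor's note: illustrative sketch, not part of the diff ---
# How the mirrored easing list above is assembled, on a tiny example. The
# quadratic `ease` below is a hypothetical stand-in for the easing_functions
# classes that EASING_FUNCTIONS_MAP resolves to; values are illustrative only.
import numpy as np


def ease(start: float, end: float, duration: int, t: int) -> float:
    alpha = t / duration if duration > 0 else 1.0
    return start + (end - start) * alpha**2


num_easing_steps = 7  # odd
base_duration = int(np.ceil(num_easing_steps / 2.0))  # -> 4
base_vals = [ease(0.0, 1.0, base_duration - 1, i) for i in range(base_duration)]
# Odd step count: drop the peak before reversing so it is not duplicated.
mirror_vals = list(reversed(base_vals[:-1]))
assert len(base_vals + mirror_vals) == num_easing_steps  # rises to 1.0, falls back
# --- end editor's note ---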
@@ -4,7 +4,13 @@ from typing import Optional
 
 import torch
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
     BoundingBoxField,
@@ -18,6 +24,7 @@ from invokeai.app.invocations.fields import (
     InputField,
     LatentsField,
     OutputField,
+    SD3ConditioningField,
     TensorField,
     UIComponent,
 )
@@ -426,6 +433,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
         return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
 
 
+@invocation_output("sd3_conditioning_output")
+class SD3ConditioningOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single SD3 conditioning tensor"""
+
+    conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
+
+    @classmethod
+    def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
+        return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
+
+
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""
@@ -521,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
 
 
 # endregion
+
+
+@invocation(
+    "image_batch",
+    title="Image Batch",
+    tags=["primitives", "image", "batch", "internal"],
+    category="primitives",
+    version="1.0.0",
+    classification=Classification.Special,
+)
+class ImageBatchInvocation(BaseInvocation):
+    """Create a batched generation, where the workflow is executed once for each image in the batch."""
+
+    images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
+
+    def __init__(self):
+        raise NotImplementedError("This class should never be executed or instantiated directly.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        raise NotImplementedError("This class should never be executed or instantiated directly.")
invokeai/app/invocations/sd3_denoise.py (new file, 338 lines)
@@ -0,0 +1,338 @@
from typing import Callable, Optional, Tuple

import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    SD3ConditioningField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "sd3_denoise",
    title="SD3 Denoise",
    tags=["image", "sd3"],
    category="image",
    version="1.1.0",
    classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run denoising process with a SD3 model."""

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
    )
    positive_conditioning: SD3ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: SD3ConditioningField = InputField(
        description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask.

        - Loads the mask
        - Resizes if necessary
        - Casts to same device/dtype as latents

        Args:
            context (InvocationContext): The invocation context, for loading the inpaint mask.
            latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for
                the inpaint mask.

        Returns:
            torch.Tensor | None: Inpaint mask. After the inversion below, values of 1.0 represent the regions to be
                fully denoised, and 0.0 represent the regions to be preserved.
        """
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)

        # The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
        # 1.0 represents the regions to be preserved.
        # We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        conditioning_name: str,
        joint_attention_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Load the conditioning data.
        cond_data = context.conditioning.load(conditioning_name)
        assert len(cond_data.conditionings) == 1
        sd3_conditioning = cond_data.conditionings[0]
        assert isinstance(sd3_conditioning, SD3ConditioningInfo)
        sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)

        t5_embeds = sd3_conditioning.t5_embeds
        if t5_embeds is None:
            t5_embeds = torch.zeros(
                (1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
                device=device,
                dtype=dtype,
            )

        clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
        clip_prompt_embeds = torch.nn.functional.pad(
            clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
        )

        prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
        pooled_prompt_embeds = torch.cat(
            [sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
        )

        return prompt_embeds, pooled_prompt_embeds

    def _get_noise(
        self,
        num_samples: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
        rand_device = "cpu"
        rand_dtype = torch.float16

        return torch.randn(
            num_samples,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
        """Prepare the CFG scale list.

        Args:
            num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
                on the scheduler used (e.g. higher order schedulers).

        Returns:
            list[float]: The CFG scale to apply at each timestep.
        """
        if isinstance(self.cfg_scale, float):
            cfg_scale = [self.cfg_scale] * num_timesteps
        elif isinstance(self.cfg_scale, list):
            assert len(self.cfg_scale) == num_timesteps
            cfg_scale = self.cfg_scale
        else:
            raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")

        return cfg_scale

    def _run_diffusion(
        self,
        context: InvocationContext,
    ):
        inference_dtype = TorchDevice.choose_torch_dtype()
        device = TorchDevice.choose_torch_device()

        transformer_info = context.models.load(self.transformer.transformer)

        # Load/process the conditioning data.
        # TODO(ryand): Make CFG optional.
        do_classifier_free_guidance = True
        pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
            context=context,
            conditioning_name=self.positive_conditioning.conditioning_name,
            joint_attention_dim=transformer_info.model.config.joint_attention_dim,
            dtype=inference_dtype,
            device=device,
        )
        neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
            context=context,
            conditioning_name=self.negative_conditioning.conditioning_name,
            joint_attention_dim=transformer_info.model.config.joint_attention_dim,
            dtype=inference_dtype,
            device=device,
        )
        # TODO(ryand): Support both sequential and batched CFG inference.
        prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)

        # Prepare the timestep schedule.
        # We add an extra step to the end to account for the final timestep of 0.0.
        timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
        # Clip the timestep schedule based on denoising_start and denoising_end.
        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
        total_steps = len(timesteps) - 1

        # Prepare the CFG scale list.
        cfg_scale = self._prepare_cfg_scale(total_steps)

        # Load the input latents, if provided.
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Generate initial latent noise.
        num_channels_latents = transformer_info.model.config.in_channels
        assert isinstance(num_channels_latents, int)
        noise = self._get_noise(
            num_samples=1,
            num_channels_latents=num_channels_latents,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )

        # Prepare the input latent image.
        if init_latents is not None:
            # Noise the init_latents by the appropriate amount for the first timestep.
            t_0 = timesteps[0]
            latents = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            # init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise

        # If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
        # denoising steps.
        if len(timesteps) <= 1:
            return latents

        # Prepare the inpaint extension.
        inpaint_mask = self._prep_inpaint_mask(context, latents)
        inpaint_extension: InpaintExtension | None = None
        if inpaint_mask is not None:
            assert init_latents is not None
            inpaint_extension = InpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        step_callback = self._build_step_callback(context)

        step_callback(
            PipelineIntermediateState(
                step=0,
                order=1,
                total_steps=total_steps,
                timestep=int(timesteps[0]),
                latents=latents,
            ),
        )

        with transformer_info.model_on_device() as (cached_weights, transformer):
            assert isinstance(transformer, SD3Transformer2DModel)

            # Denoising loop.
            for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
                # Expand the latents if we are doing CFG.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                # Expand the timestep to match the latent model input.
                # Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
                timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])

                noise_pred = transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]

                # Apply CFG.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents_dtype = latents.dtype
                latents = latents.to(dtype=torch.float32)
                latents = latents + (t_prev - t_curr) * noise_pred
                latents = latents.to(dtype=latents_dtype)

                if inpaint_extension is not None:
                    latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)

                step_callback(
                    PipelineIntermediateState(
                        step=step_idx + 1,
                        order=1,
                        total_steps=total_steps,
                        timestep=int(t_curr),
                        latents=latents,
                    ),
                )

        return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)

        return step_callback
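# --- Editor's note: illustrative sketch, not part of the diff ---
# The core of the denoising loop above, isolated: a rectified-flow Euler step
# with classifier-free guidance. The tensors below are random stand-ins for
# the SD3 transformer's outputs; shapes are hypothetical.
import torch

latents = torch.randn(1, 16, 128, 128)  # e.g. 1024 px -> 128 latent (LATENT_SCALE_FACTOR = 8)
timesteps = torch.linspace(1, 0, 10 + 1).tolist()  # steps=10, extra entry for t = 0.0
cfg_scale = 3.5

for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
    noise_pred_uncond = torch.randn_like(latents)  # stand-in: transformer(uncond)
    noise_pred_cond = torch.randn_like(latents)    # stand-in: transformer(cond)
    noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)
    latents = latents + (t_prev - t_curr) * noise_pred  # Euler step toward t = 0
# --- end editor's note ---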
invokeai/app/invocations/sd3_image_to_latents.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor


@invocation(
    "sd3_i2l",
    title="SD3 Image to Latents",
    tags=["image", "latents", "vae", "i2l", "sd3"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image."""

    image: ImageField = InputField(description="The image to encode")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        with vae_info as vae:
            assert isinstance(vae, AutoencoderKL)

            vae.disable_tiling()

            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
            with torch.inference_mode():
                image_tensor_dist = vae.encode(image_tensor).latent_dist
                # TODO: Use seed to make sampling reproducible.
                latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)

            latents = vae.config.scaling_factor * latents

        return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        vae_info = context.models.load(self.vae.vae)
        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
invokeai/app/invocations/sd3_latents_to_image.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from contextlib import nullcontext

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "sd3_l2i",
    title="SD3 Latents to Image",
    tags=["latents", "image", "vae", "l2i", "sd3"],
    category="latents",
    version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents."""

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, (AutoencoderKL))
        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
            context.util.signal_progress("Running VAE")
            assert isinstance(vae, (AutoencoderKL))
            latents = latents.to(vae.device)

            vae.disable_tiling()

            tiling_context = nullcontext()

            # clear memory as vae decode can request a lot
            TorchDevice.empty_cache()

            with torch.inference_mode(), tiling_context:
                # copied from diffusers pipeline
                latents = latents / vae.config.scaling_factor
                img = vae.decode(latents, return_dict=False)[0]

            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")  # noqa: F821
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)

        return ImageOutput.build(image_dto)
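# --- Editor's note: illustrative sketch, not part of the diff ---
# The [-1, 1] -> 8-bit RGB mapping used above, run on a dummy tensor so the
# arithmetic is easy to check in isolation. Shapes are hypothetical.
import torch
from einops import rearrange
from PIL import Image

img = torch.rand(1, 3, 64, 64) * 2 - 1  # pretend VAE output in [-1, 1]
img = rearrange(img[0].clamp(-1, 1), "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())  # uint8 HWC
# --- end editor's note ---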
invokeai/app/invocations/sd3_model_loader.py (new file, 108 lines)
@@ -0,0 +1,108 @@
from typing import Optional

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType


@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
    """SD3 base model loader output."""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
    clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "sd3_model_loader",
    title="SD3 Main Model",
    tags=["model", "sd3"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
    """Loads an SD3 base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.sd3_model,
        ui_type=UIType.SD3MainModel,
        input=Input.Direct,
    )

    t5_encoder_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.t5_encoder,
        ui_type=UIType.T5EncoderModel,
        input=Input.Direct,
        title="T5 Encoder",
        default=None,
    )

    clip_l_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.clip_embed_model,
        ui_type=UIType.CLIPLEmbedModel,
        input=Input.Direct,
        title="CLIP L Encoder",
        default=None,
    )

    clip_g_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.clip_g_model,
        ui_type=UIType.CLIPGEmbedModel,
        input=Input.Direct,
        title="CLIP G Encoder",
        default=None,
    )

    vae_model: Optional[ModelIdentifierField] = InputField(
        description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
    )

    def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        vae = (
            self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
            if self.vae_model
            else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
        )
        tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        clip_encoder_l = (
            self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
            if self.clip_l_model
            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        )
        tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        clip_encoder_g = (
            self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
            if self.clip_g_model
            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
        )
        tokenizer_t5 = (
            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
            if self.t5_encoder_model
            else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
        )
        t5_encoder = (
            self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
            if self.t5_encoder_model
            else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
        )

        return Sd3ModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
            clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
            t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
            vae=VAEField(vae=vae),
        )
invokeai/app/invocations/sd3_text_encoder.py (new file, 203 lines)
@@ -0,0 +1,203 @@
from contextlib import ExitStack
from typing import Iterator, Tuple

import torch
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    T5TokenizerFast,
)

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice

# The SD3 T5 max sequence length, set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256


@invocation(
    "sd3_text_encoder",
    title="SD3 Text Encoding",
    tags=["prompt", "conditioning", "sd3"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for an SD3 image."""

    clip_l: CLIPField = InputField(
        title="CLIP L",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )
    clip_g: CLIPField = InputField(
        title="CLIP G",
        description=FieldDescriptions.clip,
        input=Input.Connection,
    )

    # The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
    t5_encoder: T5EncoderField | None = InputField(
        title="T5Encoder",
        default=None,
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    prompt: str = InputField(description="Text prompt to encode.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
        # Note: The text encoding models are run in separate functions to ensure that all model references are locally
        # scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
        clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
        clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)

        t5_embeddings: torch.Tensor | None = None
        if self.t5_encoder is not None:
            t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)

        conditioning_data = ConditioningFieldData(
            conditionings=[
                SD3ConditioningInfo(
                    clip_l_embeds=clip_l_embeddings,
                    clip_l_pooled_embeds=clip_l_pooled_embeddings,
                    clip_g_embeds=clip_g_embeddings,
                    clip_g_pooled_embeds=clip_g_pooled_embeddings,
                    t5_embeds=t5_embeddings,
                )
            ]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return SD3ConditioningOutput.build(conditioning_name)

    def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
        assert self.t5_encoder is not None
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

        prompt = [self.prompt]

        with (
            t5_text_encoder_info as t5_text_encoder,
            t5_tokenizer_info as t5_tokenizer,
        ):
            context.util.signal_progress("Running T5 encoder")
            assert isinstance(t5_text_encoder, T5EncoderModel)
            assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

            text_inputs = t5_tokenizer(
                prompt,
                padding="max_length",
                max_length=max_seq_len,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            assert isinstance(text_input_ids, torch.Tensor)
            assert isinstance(untruncated_ids, torch.Tensor)
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
                context.logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to"
                    f" {max_seq_len} tokens: {removed_text}"
                )

            prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]

        assert isinstance(prompt_embeds, torch.Tensor)
        return prompt_embeds

    def _clip_encode(
        self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        clip_tokenizer_info = context.models.load(clip_model.tokenizer)
        clip_text_encoder_info = context.models.load(clip_model.text_encoder)

        prompt = [self.prompt]

        with (
            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
            clip_tokenizer_info as clip_tokenizer,
            ExitStack() as exit_stack,
        ):
            context.util.signal_progress("Running CLIP encoder")
            assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
            assert isinstance(clip_tokenizer, CLIPTokenizer)

            clip_text_encoder_config = clip_text_encoder_info.config
            assert clip_text_encoder_config is not None

            # Apply LoRA models to the CLIP encoder.
            # Note: We apply the LoRA after the model has been moved to its target device for faster patching.
            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
                exit_stack.enter_context(
                    LoRAPatcher.apply_smart_lora_patches(
                        model=clip_text_encoder,
                        patches=self._clip_lora_iterator(context, clip_model),
                        prefix=FLUX_LORA_CLIP_PREFIX,
                        dtype=TorchDevice.choose_torch_dtype(),
                        cached_weights=cached_weights,
                    )
                )
            else:
                # There are currently no supported quantized CLIP models. Add support here if needed.
                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")

            clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)

            text_inputs = clip_tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            assert isinstance(text_input_ids, torch.Tensor)
            assert isinstance(untruncated_ids, torch.Tensor)
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
                context.logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer_max_length} tokens: {removed_text}"
                )
            prompt_embeds = clip_text_encoder(
                input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
            )
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]

        return prompt_embeds, pooled_prompt_embeds

    def _clip_lora_iterator(
        self, context: InvocationContext, clip_model: CLIPField
    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in clip_model.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, LoRAModelRaw)
            yield (lora_info.model, lora.weight)
            del lora_info
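# --- Editor's note: illustrative sketch, not part of the diff ---
# How SD3DenoiseInvocation._load_text_conditioning combines the embeddings
# produced by this encoder: CLIP L/G hidden states are concatenated on the
# feature dim, zero-padded to the T5 width, then concatenated with the T5
# embeds on the sequence dim. Dims assume the standard CLIP-L/CLIP-G/T5-XXL sizes.
import torch

clip_l = torch.randn(1, 77, 768)
clip_g = torch.randn(1, 77, 1280)
t5 = torch.randn(1, 256, 4096)  # SD3_T5_MAX_SEQ_LEN = 256

clip = torch.cat([clip_l, clip_g], dim=-1)  # (1, 77, 2048)
clip = torch.nn.functional.pad(clip, (0, t5.shape[-1] - clip.shape[-1]))  # (1, 77, 4096)
prompt_embeds = torch.cat([clip, t5], dim=-2)  # (1, 333, 4096)
# --- end editor's note ---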
@@ -5,7 +5,7 @@ from typing import Literal
 import numpy as np
 import torch
 from PIL import Image
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from transformers import AutoModelForMaskGeneration, AutoProcessor
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
         default="all",
     )
 
-    @model_validator(mode="after")
-    def check_point_lists_or_bounding_box(self):
-        if self.point_lists is None and self.bounding_boxes is None:
-            raise ValueError("Either point_lists or bounding_box must be provided.")
-        elif self.point_lists is not None and self.bounding_boxes is not None:
-            raise ValueError("Only one of point_lists or bounding_box can be provided.")
-        return self
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
+        if self.point_lists is not None and self.bounding_boxes is not None:
+            raise ValueError("Only one of point_lists or bounding_box can be provided.")
+
         if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
             not self.point_lists or len(self.point_lists) == 0
         ):
@@ -207,7 +207,9 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         with (
             ExitStack() as exit_stack,
             unet_info as unet,
-            LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
+            LoRAPatcher.apply_smart_lora_patches(
+                model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -4,6 +4,7 @@
 from __future__ import annotations
 
 import copy
+import filecmp
 import locale
 import os
 import re
@@ -525,9 +526,35 @@ def get_config() -> InvokeAIAppConfig:
     ]
     example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
 
-    # Copy all legacy configs - We know `__path__[0]` is correct here
+    # Copy all legacy configs only if needed
+    # We know `__path__[0]` is correct here
     configs_src = Path(model_configs.__path__[0])  # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
-    shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
+    dest_path = config.legacy_conf_path
+
+    # Create destination (we don't need to check for existence)
+    dest_path.mkdir(parents=True, exist_ok=True)
+
+    # Compare directories recursively
+    comparison = filecmp.dircmp(configs_src, dest_path)
+    need_copy = any(
+        [
+            comparison.left_only,  # Files exist only in source
+            comparison.diff_files,  # Files that differ
+            comparison.common_funny,  # Files that couldn't be compared
+        ]
+    )
+
+    if need_copy:
+        # Get permissions from destination directory
+        dest_mode = dest_path.stat().st_mode
+
+        # Copy directory tree
+        shutil.copytree(configs_src, dest_path, dirs_exist_ok=True)
+
+        # Set permissions on copied files to match destination directory
+        dest_path.chmod(dest_mode)
+        for p in dest_path.glob("**/*"):
+            p.chmod(dest_mode)
 
     if config.config_file_path.exists():
         config_from_file = load_and_migrate_config(config.config_file_path)
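# --- Editor's note: illustrative sketch, not part of the diff ---
# The copy-only-if-needed pattern introduced above, on hypothetical paths.
# Note that dircmp's left_only/diff_files/common_funny compare the top level
# of the two directories; nested subdirectories would need a recursive walk.
import filecmp
import shutil
from pathlib import Path

src, dest = Path("configs_src"), Path("legacy_conf")  # hypothetical paths
src.mkdir(exist_ok=True)
dest.mkdir(parents=True, exist_ok=True)

cmp = filecmp.dircmp(src, dest)
if cmp.left_only or cmp.diff_files or cmp.common_funny:
    shutil.copytree(src, dest, dirs_exist_ok=True)  # only copy when something changed
# --- end editor's note ---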
@@ -86,7 +86,7 @@ class ModelLoadService(ModelLoadServiceBase):
 
         def torch_load_file(checkpoint: Path) -> AnyModel:
             scan_result = scan_file_path(checkpoint)
-            if scan_result.infected_files != 0:
+            if scan_result.infected_files != 0 or scan_result.scan_err:
                 raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
             result = torch_load(checkpoint, map_location="cpu")
             return result
@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
+    ClipVariantType,
     ControlAdapterDefaultSettings,
     MainModelDefaultSettings,
     ModelFormat,
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
 
     # Checkpoint-specific changes
     # TODO(MM2): Should we expose these? Feels footgun-y...
-    variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
+    variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
     prediction_type: Optional[SchedulerPredictionType] = Field(
         description="The prediction type of the model.", default=None
     )
@@ -378,6 +378,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
             self._poll_now()
 
     async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
+        # Make sure the cancel event is for the currently processing queue item
+        if self._queue_item and self._queue_item.item_id != event[1].item_id:
+            return
         if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
             # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
             # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
@@ -16,6 +16,7 @@ from pydantic import (
 from pydantic_core import to_jsonable_python
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation
+from invokeai.app.invocations.fields import ImageField
 from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
 from invokeai.app.services.workflow_records.workflow_records_common import (
     WorkflowWithoutID,
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
 
 # region Batch
 
-BatchDataType = Union[
-    StrictStr,
-    float,
-    int,
-]
+BatchDataType = Union[StrictStr, float, int, ImageField]
 
 
 class NodeFieldValue(BaseModel):
@@ -160,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):
 
 
 class ImagesInterface(InvocationContextInterface):
+    def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
+        super().__init__(services, data)
+        self._util = util
+
     def save(
         self,
         image: Image,
@@ -186,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
             The saved image DTO.
         """
 
+        self._util.signal_progress("Saving image")
+
         # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
         metadata_ = None
         if metadata:
@@ -336,6 +342,10 @@ class ConditioningInterface(InvocationContextInterface):
 class ModelsInterface(InvocationContextInterface):
     """Common API for loading, downloading and managing models."""
 
+    def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
+        super().__init__(services, data)
+        self._util = util
+
     def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
         """Check if a model exists.
 
@@ -368,11 +378,15 @@ class ModelsInterface(InvocationContextInterface):
 
         if isinstance(identifier, str):
            model = self._services.model_manager.store.get_model(identifier)
-            return self._services.model_manager.load.load_model(model, submodel_type)
         else:
-            _submodel_type = submodel_type or identifier.submodel_type
+            submodel_type = submodel_type or identifier.submodel_type
             model = self._services.model_manager.store.get_model(identifier.key)
-            return self._services.model_manager.load.load_model(model, _submodel_type)
 
+        message = f"Loading model {model.name}"
+        if submodel_type:
+            message += f" ({submodel_type.value})"
+        self._util.signal_progress(message)
+        return self._services.model_manager.load.load_model(model, submodel_type)
+
     def load_by_attrs(
         self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -397,6 +411,10 @@ class ModelsInterface(InvocationContextInterface):
         if len(configs) > 1:
             raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
 
+        message = f"Loading model {name}"
+        if submodel_type:
+            message += f" ({submodel_type.value})"
+        self._util.signal_progress(message)
         return self._services.model_manager.load.load_model(configs[0], submodel_type)
 
     def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
@@ -467,6 +485,7 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             Path to the downloaded model
         """
+        self._util.signal_progress(f"Downloading model {source}")
         return self._services.model_manager.install.download_and_cache_model(source=source)
 
     def load_local_model(
@@ -489,6 +508,8 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             A LoadedModelWithoutConfig object.
         """
 
+        self._util.signal_progress(f"Loading model {model_path.name}")
+
         return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
 
     def load_remote_model(
@@ -514,6 +535,8 @@ class ModelsInterface(InvocationContextInterface):
             A LoadedModelWithoutConfig object.
         """
         model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
 
+        self._util.signal_progress(f"Loading model {source}")
+
         return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
 
 
@@ -707,12 +730,12 @@ def build_invocation_context(
     """
 
     logger = LoggerInterface(services=services, data=data)
-    images = ImagesInterface(services=services, data=data)
     tensors = TensorsInterface(services=services, data=data)
-    models = ModelsInterface(services=services, data=data)
     config = ConfigInterface(services=services, data=data)
     util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
     conditioning = ConditioningInterface(services=services, data=data)
+    models = ModelsInterface(services=services, data=data, util=util)
+    images = ImagesInterface(services=services, data=data, util=util)
     boards = BoardsInterface(services=services, data=data)
 
     ctx = InvocationContext(
@@ -35,7 +35,7 @@ class Migration11Callback:
 
     def _remove_convert_cache(self) -> None:
         """Rename models/.cache to models/.convert_cache."""
-        self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
+        self._logger.info("Removing models/.cache directory. Converted models will now be cached in .convert_cache.")
         legacy_convert_path = self._app_config.root_path / "models" / ".cache"
         shutil.rmtree(legacy_convert_path, ignore_errors=True)
@@ -0,0 +1,382 @@
{
  "name": "SD3.5 Text to Image",
  "author": "InvokeAI",
  "description": "Sample text to image workflow for Stable Diffusion 3.5",
  "version": "1.0.0",
  "contact": "invoke@invoke.ai",
  "tags": "text2image, SD3.5, default",
  "notes": "",
  "exposedFields": [
    {
      "nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "fieldName": "model"
    },
    {
      "nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
      "fieldName": "prompt"
    }
  ],
  "meta": {
    "version": "3.0.0",
    "category": "default"
  },
  "id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
  "nodes": [
    {
      "id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "type": "invocation",
      "data": {
        "id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
        "type": "sd3_model_loader",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "nodePack": "invokeai",
        "inputs": {
          "model": {
            "name": "model",
            "label": "",
            "value": {
              "key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
              "hash": "placeholder",
              "name": "stable-diffusion-3.5-medium",
              "base": "sd-3",
              "type": "main"
            }
          },
          "t5_encoder_model": {
            "name": "t5_encoder_model",
            "label": ""
          },
          "clip_l_model": {
            "name": "clip_l_model",
            "label": ""
          },
          "clip_g_model": {
            "name": "clip_g_model",
            "label": ""
          },
          "vae_model": {
            "name": "vae_model",
            "label": ""
          }
        }
      },
      "position": {
        "x": -55.58689609637031,
        "y": -111.53602444662268
      }
    },
    {
      "id": "f7e394ac-6394-4096-abcb-de0d346506b3",
      "type": "invocation",
      "data": {
        "id": "f7e394ac-6394-4096-abcb-de0d346506b3",
        "type": "rand_int",
        "version": "1.0.1",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "nodePack": "invokeai",
        "inputs": {
          "low": {
            "name": "low",
            "label": "",
            "value": 0
          },
          "high": {
            "name": "high",
            "label": "",
            "value": 2147483647
          }
        }
      },
      "position": {
        "x": 470.45870147220353,
        "y": 350.3141781644303
      }
    },
    {
      "id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
      "type": "invocation",
      "data": {
        "id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
        "type": "sd3_l2i",
        "version": "1.3.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": false,
        "useCache": true,
        "nodePack": "invokeai",
        "inputs": {
          "board": {
            "name": "board",
            "label": ""
          },
          "metadata": {
            "name": "metadata",
            "label": ""
          },
          "latents": {
            "name": "latents",
            "label": ""
          },
          "vae": {
            "name": "vae",
            "label": ""
          }
        }
      },
      "position": {
        "x": 1192.3097009334897,
        "y": -366.0994675072209
      }
    },
    {
      "id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
      "type": "invocation",
      "data": {
        "id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
        "type": "sd3_text_encoder",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "nodePack": "invokeai",
        "inputs": {
          "clip_l": {
            "name": "clip_l",
            "label": ""
          },
          "clip_g": {
            "name": "clip_g",
            "label": ""
          },
          "t5_encoder": {
            "name": "t5_encoder",
            "label": ""
          },
          "prompt": {
            "name": "prompt",
            "label": "",
            "value": ""
          }
        }
      },
      "position": {
        "x": 408.16054647924784,
        "y": 65.06415352118786
      }
    },
    {
      "id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
      "type": "invocation",
      "data": {
        "id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
        "type": "sd3_text_encoder",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "nodePack": "invokeai",
        "inputs": {
          "clip_l": {
            "name": "clip_l",
            "label": ""
          },
          "clip_g": {
            "name": "clip_g",
            "label": ""
          },
          "t5_encoder": {
            "name": "t5_encoder",
            "label": ""
          },
          "prompt": {
            "name": "prompt",
            "label": "",
            "value": ""
          }
        }
      },
      "position": {
        "x": 378.9283412440941,
        "y": -302.65777497352553
      }
    },
    {
      "id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
      "type": "invocation",
      "data": {
        "id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
        "type": "sd3_denoise",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "nodePack": "invokeai",
        "inputs": {
          "board": {
            "name": "board",
            "label": ""
          },
          "metadata": {
            "name": "metadata",
            "label": ""
          },
          "transformer": {
            "name": "transformer",
            "label": ""
          },
          "positive_conditioning": {
            "name": "positive_conditioning",
            "label": ""
          },
          "negative_conditioning": {
            "name": "negative_conditioning",
            "label": ""
          },
          "cfg_scale": {
            "name": "cfg_scale",
            "label": "",
            "value": 3.5
          },
          "width": {
            "name": "width",
            "label": "",
            "value": 1024
          },
          "height": {
            "name": "height",
            "label": "",
            "value": 1024
          },
          "steps": {
            "name": "steps",
            "label": "",
            "value": 30
          },
          "seed": {
            "name": "seed",
            "label": "",
            "value": 0
          }
        }
      },
      "position": {
        "x": 813.7814762740603,
        "y": -142.20529727605867
      }
    }
  ],
  "edges": [
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
      "sourceHandle": "vae",
      "targetHandle": "vae"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
      "sourceHandle": "t5_encoder",
      "targetHandle": "t5_encoder"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
      "sourceHandle": "t5_encoder",
      "targetHandle": "t5_encoder"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
      "sourceHandle": "clip_g",
      "targetHandle": "clip_g"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
      "sourceHandle": "clip_g",
      "targetHandle": "clip_g"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
      "sourceHandle": "clip_l",
      "targetHandle": "clip_l"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
      "sourceHandle": "clip_l",
      "targetHandle": "clip_l"
    },
    {
      "id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
      "type": "default",
      "source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
      "target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
      "sourceHandle": "transformer",
      "targetHandle": "transformer"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
|
||||
"type": "default",
|
||||
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
|
||||
"type": "default",
|
||||
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
|
||||
"type": "default",
|
||||
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
|
||||
"type": "default",
|
||||
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
|
||||
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
}
|
||||
]
|
||||
}
|
@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
    [-0.1307, -0.1874, -0.7445],  # L4
]

+SD3_5_LATENT_RGB_FACTORS = [
+    [-0.05240681, 0.03251581, 0.0749016],
+    [-0.0580572, 0.00759826, 0.05729818],
+    [0.16144888, 0.01270368, -0.03768577],
+    [0.14418615, 0.08460266, 0.15941818],
+    [0.04894035, 0.0056485, -0.06686988],
+    [0.05187166, 0.19222395, 0.06261094],
+    [0.1539433, 0.04818359, 0.07103094],
+    [-0.08601796, 0.09013458, 0.10893912],
+    [-0.12398469, -0.06766567, 0.0033688],
+    [-0.0439737, 0.07825329, 0.02258823],
+    [0.03101129, 0.06382551, 0.07753657],
+    [-0.01315361, 0.08554491, -0.08772475],
+    [0.06464487, 0.05914605, 0.13262741],
+    [-0.07863674, -0.02261737, -0.12761454],
+    [-0.09923835, -0.08010759, -0.06264447],
+    [-0.03392309, -0.0804029, -0.06078822],
+]

FLUX_LATENT_RGB_FACTORS = [
    [-0.0412, 0.0149, 0.0521],
    [0.0056, 0.0291, 0.0768],
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
        sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
        sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
        image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
+    elif base_model == BaseModelType.StableDiffusion3:
+        sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
+        image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
    else:
        v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
        image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

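For orientation, here is a minimal sketch (not part of the diff; toy shapes and values) of how a latent-to-RGB factor matrix like SD3_5_LATENT_RGB_FACTORS is typically applied to produce the low-res progress preview:

import torch

factors = torch.randn(16, 3)         # stands in for the (16, 3) SD3_5_LATENT_RGB_FACTORS above
sample = torch.randn(1, 16, 32, 32)  # latent sample: (batch, channels, height, width)

# Weighted sum over the channel axis: (B, C, H, W) x (C, 3) -> (B, H, W, 3).
rgb = torch.einsum("bchw,cr->bhwr", sample, factors)

# Rescale to a displayable 8-bit range.
rgb = ((rgb - rgb.min()) / (rgb.max() - rgb.min()) * 255).to(torch.uint8)
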
@@ -1,9 +1,10 @@
import einops
import torch

+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
-from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock


class CustomDoubleStreamBlockProcessor:
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:

    @staticmethod
    def _double_stream_block_forward(
-        block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
+        block: DoubleStreamBlock,
+        img: torch.Tensor,
+        txt: torch.Tensor,
+        vec: torch.Tensor,
+        pe: torch.Tensor,
+        attn_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
        values.
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

-        attn = attention(q, k, v, pe=pe)
+        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
        vec: torch.Tensor,
        pe: torch.Tensor,
        ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """A custom implementation of DoubleStreamBlock.forward() with additional features:
        - IP-Adapter support
        """
-        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
+        attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
+        img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
+            block, img, txt, vec, pe, attn_mask=attn_mask
+        )

        # Apply IP-Adapter conditioning.
        for ip_adapter_extension in ip_adapter_extensions:
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
        )

        return img, txt


class CustomSingleStreamBlockProcessor:
    """A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
    etc.)
    """

    @staticmethod
    def _single_stream_block_forward(
        block: SingleStreamBlock,
        x: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """This function is a direct copy of SingleStreamBlock.forward()."""
        mod, _ = block.modulation(vec)
        x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)

        q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
        q, k = block.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
        return x + mod.gate * output

    @staticmethod
    def custom_single_block_forward(
        timestep_index: int,
        total_num_timesteps: int,
        block_index: int,
        block: SingleStreamBlock,
        img: torch.Tensor,
        vec: torch.Tensor,
        pe: torch.Tensor,
        regional_prompting_extension: RegionalPromptingExtension,
    ) -> torch.Tensor:
        """A custom implementation of SingleStreamBlock.forward() with additional features:
        - Masking
        """
        attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
        return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)

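A minimal sketch (not part of the diff) of how the attn_mask threaded through these processors is ultimately consumed: a boolean (seq_len, seq_len) mask over the concatenated [txt, img] sequence, where True means attention is allowed:

import torch
import torch.nn.functional as F

txt_len, img_len, num_heads, head_dim = 4, 8, 2, 16
seq_len = txt_len + img_len
q = k = v = torch.randn(1, num_heads, seq_len, head_dim)

# A 2-D mask broadcasts over the batch and head dimensions.
attn_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
attn_mask[:txt_len, :txt_len] = True        # txt attends to itself
attn_mask[txt_len:, txt_len:] = True        # img self-attention allowed
attn_mask[: txt_len // 2, txt_len:] = True  # the first prompt half sees the img tokens...
attn_mask[txt_len:, : txt_len // 2] = True  # ...and the img tokens see it back

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
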
@@ -7,6 +7,7 @@ from tqdm import tqdm
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.model import Flux
@@ -18,14 +19,8 @@ def denoise(
    # model input
    img: torch.Tensor,
    img_ids: torch.Tensor,
-    # positive text conditioning
-    txt: torch.Tensor,
-    txt_ids: torch.Tensor,
-    vec: torch.Tensor,
-    # negative text conditioning
-    neg_txt: torch.Tensor | None,
-    neg_txt_ids: torch.Tensor | None,
-    neg_vec: torch.Tensor | None,
+    pos_regional_prompting_extension: RegionalPromptingExtension,
+    neg_regional_prompting_extension: RegionalPromptingExtension | None,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[PipelineIntermediateState], None],
@@ -35,6 +30,8 @@ def denoise(
    controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
    pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
    neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
+    # extra img tokens
+    img_cond: torch.Tensor | None = None,
):
    # step 0 is the initial state
    total_steps = len(timesteps) - 1
@@ -61,9 +58,9 @@ def denoise(
            total_num_timesteps=total_steps,
            img=img,
            img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
@@ -74,13 +71,13 @@ def denoise(
        # The controlnet_residuals data structure is efficient in that it likely contains multiple references to the
        # same tensors. Calculating the sum materializes each tensor into its own instance.
        merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)

+        pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
        pred = model(
-            img=img,
+            img=pred_img,
            img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
            timesteps=t_vec,
            guidance=guidance_vec,
            timestep_index=step_index,
@@ -88,6 +85,7 @@ def denoise(
            controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
            controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
            ip_adapter_extensions=pos_ip_adapter_extensions,
+            regional_prompting_extension=pos_regional_prompting_extension,
        )

        step_cfg_scale = cfg_scale[step_index]
@@ -97,15 +95,15 @@ def denoise(
            # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
            # on systems with sufficient VRAM.

-            if neg_txt is None or neg_txt_ids is None or neg_vec is None:
+            if neg_regional_prompting_extension is None:
                raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

            neg_pred = model(
                img=img,
                img_ids=img_ids,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
+                txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                timesteps=t_vec,
                guidance=guidance_vec,
                timestep_index=step_index,
@@ -113,6 +111,7 @@ def denoise(
                controlnet_double_block_residuals=None,
                controlnet_single_block_residuals=None,
                ip_adapter_extensions=neg_ip_adapter_extensions,
+                regional_prompting_extension=neg_regional_prompting_extension,
            )
            pred = neg_pred + step_cfg_scale * (pred - neg_pred)

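The last line of the hunk above is standard classifier-free guidance. A minimal sketch (the function name is illustrative, not from the diff):

import torch

def apply_cfg(pos_pred: torch.Tensor, neg_pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # With cfg_scale == 1.0 this reduces exactly to the positive prediction.
    return neg_pred + cfg_scale * (pos_pred - neg_pred)
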
invokeai/backend/flux/extensions/regional_prompting_extension.py (new file, 276 lines)
@@ -0,0 +1,276 @@
from typing import Optional

import torch
import torchvision

from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.mask import to_standard_float_mask


class RegionalPromptingExtension:
    """A class for managing regional prompting with FLUX.

    This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
    """

    def __init__(
        self,
        regional_text_conditioning: FluxRegionalTextConditioning,
        restricted_attn_mask: torch.Tensor | None = None,
    ):
        self.regional_text_conditioning = regional_text_conditioning
        self.restricted_attn_mask = restricted_attn_mask

    def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
        order = [self.restricted_attn_mask, None]
        return order[block_index % len(order)]

    def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
        order = [self.restricted_attn_mask, None]
        return order[block_index % len(order)]

    @classmethod
    def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
        """Create a RegionalPromptingExtension from a list of text conditionings.

        Args:
            text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
            img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
        """
        regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
        attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
            regional_text_conditioning, img_seq_len
        )
        return cls(
            regional_text_conditioning=regional_text_conditioning,
            restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
        )

    # Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
    #
    # @classmethod
    # def _prepare_unrestricted_attn_mask(
    #     cls,
    #     regional_text_conditioning: FluxRegionalTextConditioning,
    #     img_seq_len: int,
    # ) -> torch.Tensor:
    #     """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
    #     - img self-attention is not masked.
    #     - img regions attend to both txt within their own region and to global prompts.
    #     """
    #     device = TorchDevice.choose_torch_device()

    #     # Infer txt_seq_len from the t5_embeddings tensor.
    #     txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]

    #     # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
    #     # Concatenation happens in the following order: [txt_seq, img_seq].
    #     # There are 4 portions of the attention mask to consider as we prepare it:
    #     # 1. txt attends to itself
    #     # 2. txt attends to corresponding regional img
    #     # 3. regional img attends to corresponding txt
    #     # 4. regional img attends to itself

    #     # Initialize empty attention mask.
    #     regional_attention_mask = torch.zeros(
    #         (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
    #     )

    #     for image_mask, t5_embedding_range in zip(
    #         regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
    #     ):
    #         # 1. txt attends to itself
    #         regional_attention_mask[
    #             t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
    #         ] = 1.0

    #         # 2. txt attends to corresponding regional img
    #         # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
    #         fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
    #         regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value

    #         # 3. regional img attends to corresponding txt
    #         # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
    #         fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
    #         regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value

    #         # 4. regional img attends to itself
    #         # Allow unrestricted img self attention.
    #         regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0

    #     # Convert attention mask to boolean.
    #     regional_attention_mask = regional_attention_mask > 0.5

    #     return regional_attention_mask

    @classmethod
    def _prepare_restricted_attn_mask(
        cls,
        regional_text_conditioning: FluxRegionalTextConditioning,
        img_seq_len: int,
    ) -> torch.Tensor | None:
        """Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
        - img self-attention is only allowed within regions.
        - img regions only attend to txt within their own region, not to global prompts.
        """
        # Identify the background region, i.e. the region that is not covered by any region masks.
        background_region_mask: None | torch.Tensor = None
        for image_mask in regional_text_conditioning.image_masks:
            if image_mask is not None:
                if background_region_mask is None:
                    background_region_mask = torch.ones_like(image_mask)
                background_region_mask *= 1 - image_mask

        if background_region_mask is None:
            # There are no region masks, so short-circuit and return None.
            # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this is a rare
            # use case and it would make the logic here significantly more complicated.
            return None

        device = TorchDevice.choose_torch_device()

        # Infer txt_seq_len from the t5_embeddings tensor.
        txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]

        # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
        # Concatenation happens in the following order: [txt_seq, img_seq].
        # There are 4 portions of the attention mask to consider as we prepare it:
        # 1. txt attends to itself
        # 2. txt attends to corresponding regional img
        # 3. regional img attends to corresponding txt
        # 4. regional img attends to itself

        # Initialize empty attention mask.
        regional_attention_mask = torch.zeros(
            (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
        )

        for image_mask, t5_embedding_range in zip(
            regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
        ):
            # 1. txt attends to itself
            regional_attention_mask[
                t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
            ] = 1.0

            if image_mask is not None:
                # 2. txt attends to corresponding regional img
                # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
                regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
                    image_mask.view(1, img_seq_len)
                )

                # 3. regional img attends to corresponding txt
                # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
                regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
                    image_mask.view(img_seq_len, 1)
                )

                # 4. regional img attends to itself
                image_mask = image_mask.view(img_seq_len, 1)
                regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
            else:
                # We don't allow attention between non-background image regions and global prompts. This helps to ensure
                # that regions focus on their local prompts. We do, however, allow attention between background regions
                # and global prompts. If we didn't do this, then the background regions would not attend to any txt
                # embeddings, which we found experimentally to cause artifacts.

                # 2. global txt attends to background region
                # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
                regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
                    background_region_mask.view(1, img_seq_len)
                )

                # 3. background region attends to global txt
                # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
                regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
                    background_region_mask.view(img_seq_len, 1)
                )

                # Allow background regions to attend to themselves.
                regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
                regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)

        # Convert attention mask to boolean.
        regional_attention_mask = regional_attention_mask > 0.5

        return regional_attention_mask

    @classmethod
    def _concat_regional_text_conditioning(
        cls,
        text_conditionings: list[FluxTextConditioning],
    ) -> FluxRegionalTextConditioning:
        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
        concat_t5_embeddings: list[torch.Tensor] = []
        concat_t5_embedding_ranges: list[Range] = []
        image_masks: list[torch.Tensor | None] = []

        # Choose the global CLIP embedding.
        # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
        # the first prompt's CLIP embedding.
        global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
        for text_conditioning in text_conditionings:
            if text_conditioning.mask is None:
                global_clip_embedding = text_conditioning.clip_embeddings
                break

        cur_t5_embedding_len = 0
        for text_conditioning in text_conditionings:
            concat_t5_embeddings.append(text_conditioning.t5_embeddings)

            concat_t5_embedding_ranges.append(
                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
            )

            image_masks.append(text_conditioning.mask)

            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]

        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)

        # Initialize the txt_ids tensor.
        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
        t5_txt_ids = torch.zeros(
            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
        )

        return FluxRegionalTextConditioning(
            t5_embeddings=t5_embeddings,
            clip_embeddings=global_clip_embedding,
            t5_txt_ids=t5_txt_ids,
            image_masks=image_masks,
            t5_embedding_ranges=concat_t5_embedding_ranges,
        )

    @staticmethod
    def preprocess_regional_prompt_mask(
        mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target height and width.
        If mask is None, returns a mask of all ones with the target height and width.
        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.

        packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.

        Returns:
            torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
        """

        if mask is None:
            return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)

        mask = to_standard_float_mask(mask, out_dtype=dtype)

        tf = torchvision.transforms.Resize(
            (packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )

        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
        resized_mask = tf(mask)

        # Flatten the height and width dimensions into a single image_seq_len dimension.
        return resized_mask.flatten(start_dim=2)
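A minimal usage sketch for the mask preprocessing above (toy sizes; assumes FLUX's 8x VAE downscale followed by 2x2 packing, i.e. packed dimensions of height/16 and width/16):

import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension

height, width = 1024, 1024
packed_h, packed_w = height // 16, width // 16

mask = torch.zeros(1, height, width)
mask[:, :, : width // 2] = 1.0  # the region covers the left half of the image

packed_mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
    mask, packed_h, packed_w, dtype=torch.float32, device=torch.device("cpu")
)
print(packed_mask.shape)  # torch.Size([1, 1, 4096])
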
invokeai/backend/flux/flux_tools_sampling_utils.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch
import numpy as np
from PIL import Image
from einops import rearrange

from invokeai.backend.flux.modules.autoencoder import AutoEncoder


def prepare_control(
    height: int,
    width: int,
    seed: int,
    ae: AutoEncoder,
    cond_image: Image.Image,
) -> torch.Tensor:
    # load and encode the conditioning image
    img_cond = cond_image.convert("RGB")
    img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
    img_cond = np.array(img_cond)
    img_cond = torch.from_numpy(img_cond).float()
    img_cond = rearrange(img_cond, "h w c -> 1 c h w")
    ae_dtype = next(iter(ae.parameters())).dtype
    ae_device = next(iter(ae.parameters())).device
    img_cond = img_cond.to(device=ae_device, dtype=ae_dtype)
    generator = torch.Generator(device=ae_device).manual_seed(seed)
    img_cond = ae.encode(img_cond, sample=True, generator=generator)
    img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    return img_cond
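A minimal sketch (toy shapes) of what the final rearrange in prepare_control() does: each 2x2 patch of latent pixels is folded into the channel dimension, yielding packed tokens:

import torch
from einops import rearrange

latents = torch.randn(1, 16, 128, 128)  # (B, C, H, W) from the autoencoder
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)  # torch.Size([1, 4096, 64]) -> 64*64 tokens, each of width 16*2*2
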
@@ -41,10 +41,12 @@ def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Te
    hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
    context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
+    clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
+    clip_extra_context_tokens = state_dict["ip_adapter_proj_model.proj.weight"].shape[0] // context_dim

    return XlabsIpAdapterParams(
        num_double_blocks=num_double_blocks,
        context_dim=context_dim,
        hidden_dim=hidden_dim,
+        clip_embeddings_dim=clip_embeddings_dim,
+        clip_extra_context_tokens=clip_extra_context_tokens,
    )

@@ -31,13 +31,16 @@ class XlabsIpAdapterParams:
    hidden_dim: int

+    clip_embeddings_dim: int
+    clip_extra_context_tokens: int


class XlabsIpAdapterFlux(torch.nn.Module):
    def __init__(self, params: XlabsIpAdapterParams):
        super().__init__()
        self.image_proj = ImageProjModel(
-            cross_attention_dim=params.context_dim, clip_embeddings_dim=params.clip_embeddings_dim
+            cross_attention_dim=params.context_dim,
+            clip_embeddings_dim=params.clip_embeddings_dim,
+            clip_extra_context_tokens=params.clip_extra_context_tokens,
        )
        self.ip_adapter_double_blocks = IPAdapterDoubleBlocks(
            num_double_blocks=params.num_double_blocks, context_dim=params.context_dim, hidden_dim=params.hidden_dim

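A worked example (toy numbers, not from the diff) of the shape arithmetic behind clip_extra_context_tokens: the projection weight maps CLIP embeddings to N context tokens of width context_dim, so its output dimension is N * context_dim:

context_dim = 4096
clip_embeddings_dim = 768
num_tokens = 16

proj_weight_shape = (num_tokens * context_dim, clip_embeddings_dim)  # (65536, 768)
clip_extra_context_tokens = proj_weight_shape[0] // context_dim
assert clip_extra_context_tokens == num_tokens
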
@@ -5,10 +5,10 @@ from einops import rearrange
from torch import Tensor


-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
    q, k = apply_rope(q, k, pe)

-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x
@@ -24,12 +24,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    return out.to(dtype=pos.dtype, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+    return xq_out.view(*xq.shape).type_as(xq), xk_out.view(*xk.shape).type_as(xk)

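For reference, a minimal numeric sketch of the 2x2 rotation that rope()/apply_rope() implement — each feature pair is rotated by a position-dependent angle (the change in this hunk only affects dtype handling, not the math):

import math

import torch

t = 0.5  # a position-dependent angle
rot = torch.tensor([[math.cos(t), -math.sin(t)],
                    [math.sin(t), math.cos(t)]])
pair = torch.tensor([1.0, 0.0])
print(rot @ pair)  # tensor([0.8776, 0.4794]) == (cos t, sin t)
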
@@ -4,8 +4,13 @@ from dataclasses import dataclass

import torch
from torch import Tensor, nn
+from typing import Optional

-from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
+from invokeai.backend.flux.custom_block_processor import (
+    CustomDoubleStreamBlockProcessor,
+    CustomSingleStreamBlockProcessor,
+)
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.modules.layers import (
    DoubleStreamBlock,
@@ -31,6 +36,7 @@ class FluxParams:
    theta: int
    qkv_bias: bool
    guidance_embed: bool
+    out_channels: Optional[int] = None


class Flux(nn.Module):
@@ -43,7 +49,7 @@ class Flux(nn.Module):

        self.params = params
        self.in_channels = params.in_channels
-        self.out_channels = self.in_channels
+        self.out_channels = params.out_channels or self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
@@ -95,6 +101,7 @@ class Flux(nn.Module):
        controlnet_double_block_residuals: list[Tensor] | None,
        controlnet_single_block_residuals: list[Tensor] | None,
        ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,7 +124,6 @@ class Flux(nn.Module):
            assert len(controlnet_double_block_residuals) == len(self.double_blocks)
        for block_index, block in enumerate(self.double_blocks):
            assert isinstance(block, DoubleStreamBlock)
-
            img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
                timestep_index=timestep_index,
                total_num_timesteps=total_num_timesteps,
@@ -128,6 +134,7 @@ class Flux(nn.Module):
                vec=vec,
                pe=pe,
                ip_adapter_extensions=ip_adapter_extensions,
+                regional_prompting_extension=regional_prompting_extension,
            )

            if controlnet_double_block_residuals is not None:
@@ -140,7 +147,17 @@ class Flux(nn.Module):
            assert len(controlnet_single_block_residuals) == len(self.single_blocks)

        for block_index, block in enumerate(self.single_blocks):
-            img = block(img, vec=vec, pe=pe)
+            assert isinstance(block, SingleStreamBlock)
+            img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
+                timestep_index=timestep_index,
+                total_num_timesteps=total_num_timesteps,
+                block_index=block_index,
+                block=block,
+                img=img,
+                vec=vec,
+                pe=pe,
+                regional_prompting_extension=regional_prompting_extension,
+            )

            if controlnet_single_block_residuals is not None:
                img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]

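A minimal sketch of the token layout assumed by the last line of the hunk: in the single-stream blocks the sequence is the concatenation [txt, img], so ControlNet residuals touch only the img slice:

import torch

txt_len, img_len, dim = 4, 8, 16
x = torch.zeros(1, txt_len + img_len, dim)
residual = torch.ones(1, img_len, dim)

x[:, txt_len:, ...] += residual
assert x[:, :txt_len].abs().sum() == 0  # txt tokens are untouched
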
invokeai/backend/flux/modules/image_embedders.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import os
import cv2
import numpy as np
import torch

from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel


class DepthImageEncoder:
    depth_model_name = "LiheYoung/depth-anything-large-hf"

    def __init__(self, device):
        self.device = device
        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        hw = img.shape[-2:]
        img = torch.clamp(img, -1.0, 1.0)
        img_byte = ((img + 1.0) * 127.5).byte()
        img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
        depth = self.depth_model(img.to(self.device)).predicted_depth
        depth = repeat(depth, "b h w -> b 3 h w")
        depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
        depth = depth / 127.5 - 1.0
        return depth


class CannyImageEncoder:
    def __init__(
        self,
        device,
        min_t: int = 50,
        max_t: int = 200,
    ):
        self.device = device
        self.min_t = min_t
        self.max_t = max_t

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        assert img.shape[0] == 1, "Only batch size 1 is supported"
        img = rearrange(img[0], "c h w -> h w c")
        img = torch.clamp(img, -1.0, 1.0)
        img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
        # Apply Canny edge detection
        canny = cv2.Canny(img_np, self.min_t, self.max_t)
        # Convert back to torch tensor and reshape
        canny = torch.from_numpy(canny).float() / 127.5 - 1.0
        canny = rearrange(canny, "h w -> 1 1 h w")
        canny = repeat(canny, "b 1 ... -> b 3 ...")
        return canny.to(self.device)
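A minimal sketch of the value mapping used by CannyImageEncoder: cv2.Canny outputs {0, 255}, and dividing by 127.5 then subtracting 1 maps this onto the [-1, 1] range the rest of the pipeline expects:

import numpy as np
import torch

canny_np = np.array([[0, 255], [255, 0]], dtype=np.uint8)
canny = torch.from_numpy(canny_np).float() / 127.5 - 1.0
print(canny)  # tensor([[-1.,  1.], [ 1., -1.]])
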
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
-        x_dtype = x.dtype
-        x = x.float()
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms).to(dtype=x_dtype) * self.scale
+        return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)


class QKNorm(torch.nn.Module):

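A quick equivalence check (not part of the diff; requires a PyTorch version that provides torch.nn.functional.rms_norm) showing that the fused call matches the manual implementation it replaces:

import torch

x = torch.randn(2, 8)
scale = torch.ones(8)

# Manual version, as in the removed lines above (x is already float32 here).
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
manual = (x * rrms) * scale

fused = torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6)
print(torch.allclose(manual, fused, atol=1e-5))  # True
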
invokeai/backend/flux/text_conditioning.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from dataclasses import dataclass

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range


@dataclass
class FluxTextConditioning:
    t5_embeddings: torch.Tensor
    clip_embeddings: torch.Tensor
    # If mask is None, the prompt is a global prompt.
    mask: torch.Tensor | None


@dataclass
class FluxRegionalTextConditioning:
    # Concatenated text embeddings.
    # Shape: (1, concatenated_txt_seq_len, 4096)
    t5_embeddings: torch.Tensor
    # Shape: (1, concatenated_txt_seq_len, 3)
    t5_txt_ids: torch.Tensor

    # Global CLIP embeddings.
    # Shape: (1, 768)
    clip_embeddings: torch.Tensor

    # A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
    # global prompt.
    # image_masks[i] is the mask for the ith prompt.
    # image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
    image_masks: list[torch.Tensor | None]

    # List of ranges that represent the embedding ranges for each mask.
    # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
    t5_embedding_ranges: list[Range]
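A minimal construction sketch for these dataclasses (toy shapes; the exact shape expected for FluxTextConditioning.mask is an assumption here, following the packed image_seq_len convention documented above):

import torch

from invokeai.backend.flux.text_conditioning import FluxTextConditioning

global_prompt = FluxTextConditioning(
    t5_embeddings=torch.randn(1, 512, 4096),
    clip_embeddings=torch.randn(1, 768),
    mask=None,  # None => a global prompt that applies to the whole image
)
regional_prompt = FluxTextConditioning(
    t5_embeddings=torch.randn(1, 512, 4096),
    clip_embeddings=torch.randn(1, 768),
    mask=torch.ones(1, 4096, dtype=torch.bool),  # image_seq_len == 4096
)
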
invokeai/backend/image_util/assets/CIELab_to_UPLab.icc (new binary file, not shown)
invokeai/backend/image_util/composition.py (new file, 1020 lines; diff suppressed because it is too large)
invokeai/backend/lora/conversions/flux_control_lora_utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import re
import torch

from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
#   guidance_in.in_layer.lora_B.bias
#   single_blocks.0.linear1.lora_A.weight
#   double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"


def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
    """Checks if the provided state dict is likely in the FLUX Control LoRA format.

    This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
    perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
    """
    return all(re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys())


def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
    # converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
    # Group keys by layer.
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        key_props = key.split(".")
        # Note: loading via lora_down/lora_up was also tried, but it did not seem to match this LoRA's structure.
        # This handling is kept since it doesn't hurt anything and may be more robust.
        layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
        layer_name = ".".join(key_props[:layer_prop_size])
        param_name = ".".join(key_props[layer_prop_size:])
        if layer_name not in grouped_state_dict:
            grouped_state_dict[layer_name] = {}
        grouped_state_dict[layer_name][param_name] = value

    # Create LoRA layers.
    layers: dict[str, AnyLoRALayer] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        # Convert to a full layer diff
        prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
        if all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
            layers[prefixed_key] = LoRALayer(
                layer_state_dict["lora_B.weight"],
                None,
                layer_state_dict["lora_A.weight"],
                None,
                layer_state_dict["lora_B.bias"],
            )
        elif "scale" in layer_state_dict:
            layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
        else:
            raise AssertionError(f"{layer_key} not expected")
    # Create and return the LoRAModelRaw.
    return LoRAModelRaw(layers=layers)

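A minimal sketch of the key-grouping rule in the loop above: lora_A/lora_B keys keep their last two dotted components as the parameter name, while plain parameters keep only the last one:

key = "single_blocks.0.linear1.lora_A.weight"
props = key.split(".")
size = -2 if any(p in key for p in ["lora_B", "lora_A"]) else -1
layer_name, param_name = ".".join(props[:size]), ".".join(props[size:])
print(layer_name, "|", param_name)  # single_blocks.0.linear1 | lora_A.weight
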
@@ -45,8 +45,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
    # Constants for FLUX.1
    num_double_layers = 19
    num_single_layers = 38
-    # inner_dim = 3072
-    # mlp_ratio = 4.0
+    hidden_size = 3072
+    mlp_ratio = 4.0
+    mlp_hidden_dim = int(hidden_size * mlp_ratio)

    layers: dict[str, AnyLoRALayer] = {}

@@ -62,30 +63,43 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
            layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
            assert len(src_layer_dict) == 0

-    def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
+    def add_qkv_lora_layer_if_present(
+        src_keys: list[str],
+        src_weight_shapes: list[tuple[int, int]],
+        dst_qkv_key: str,
+        allow_missing_keys: bool = False,
+    ) -> None:
        """Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
        stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
        """
-        # We expect that either all src keys are present or none of them are. Verify this.
-        keys_present = [key in grouped_state_dict for key in src_keys]
-        assert all(keys_present) or not any(keys_present)
-
        # If none of the keys are present, return early.
        keys_present = [key in grouped_state_dict for key in src_keys]
        if not any(keys_present):
            return

-        src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
        sub_layers: list[LoRALayer] = []
-        for src_layer_dict in src_layer_dicts:
-            values = {
-                "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
-                "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
-            }
-            if alpha is not None:
-                values["alpha"] = torch.tensor(alpha)
-            sub_layers.append(LoRALayer.from_state_dict_values(values=values))
-            assert len(src_layer_dict) == 0
-        layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
+        for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
+            src_layer_dict = grouped_state_dict.pop(src_key, None)
+            if src_layer_dict is not None:
+                values = {
+                    "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
+                    "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
+                }
+                if alpha is not None:
+                    values["alpha"] = torch.tensor(alpha)
+                assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
+                assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
+                sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+                assert len(src_layer_dict) == 0
+            else:
+                if not allow_missing_keys:
+                    raise ValueError(f"Missing LoRA layer: '{src_key}'.")
+                values = {
+                    "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
+                    "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
+                }
+                sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+        layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)

    # time_text_embed.timestep_embedder -> time_in.
    add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
@@ -118,6 +132,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
            f"transformer_blocks.{i}.attn.to_k",
            f"transformer_blocks.{i}.attn.to_v",
        ],
+        [(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
        f"double_blocks.{i}.img_attn.qkv",
    )
    add_qkv_lora_layer_if_present(
@@ -126,6 +141,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
            f"transformer_blocks.{i}.attn.add_k_proj",
            f"transformer_blocks.{i}.attn.add_v_proj",
        ],
+        [(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
        f"double_blocks.{i}.txt_attn.qkv",
    )

@@ -175,7 +191,14 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
            f"single_transformer_blocks.{i}.attn.to_v",
            f"single_transformer_blocks.{i}.proj_mlp",
        ],
+        [
+            (hidden_size, hidden_size),
+            (hidden_size, hidden_size),
+            (hidden_size, hidden_size),
+            (mlp_hidden_dim, hidden_size),
+        ],
        f"single_blocks.{i}.linear1",
+        allow_missing_keys=True,
    )

    # Output projections.

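A minimal sketch of why the zero-filled placeholder above is safe: a rank-1 pair of all-zero lora_up/lora_down matrices contributes exactly zero residual, while keeping the concatenated QKV layer's shape bookkeeping consistent:

import torch

out_features, in_features = 3072, 3072
lora_up = torch.zeros(out_features, 1)
lora_down = torch.zeros(1, in_features)

x = torch.randn(4, in_features)
residual = x @ lora_down.T @ lora_up.T  # (4, out_features), identically zero
assert residual.abs().sum() == 0
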
@@ -7,5 +7,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
+from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer

-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]

invokeai/backend/lora/layers/reshape_weight_layer.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class ReshapeWeightLayer(LoRALayerBase):
    # TODO: Just about everything in this class still needs work.
    def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
        super().__init__(alpha=None, bias=bias)
        self.weight = torch.nn.Parameter(weight) if weight is not None else None
        self.bias = torch.nn.Parameter(bias) if bias is not None else None
        self.manual_scale = scale

    def scale(self):
        return self.manual_scale.float() if self.manual_scale is not None else super().scale()

    def rank(self) -> int | None:
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return orig_weight

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        if self.weight is not None:
            self.weight = self.weight.to(device=device, dtype=dtype)
        if self.manual_scale is not None:
            self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensor_size(self.manual_scale)
invokeai/backend/lora/layers/set_parameter_layer.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import Dict, Optional

import torch

from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size


class SetParameterLayer(LoRALayerBase):
    def __init__(self, param_name: str, weight: torch.Tensor):
        super().__init__(None, None)
        self.weight = weight
        self.param_name = param_name

    def rank(self) -> int | None:
        return None

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        return self.weight - orig_weight

    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
        return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().to(device=device, dtype=dtype)
        self.weight = self.weight.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        return super().calc_size() + calc_tensor_size(self.weight)
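A minimal sketch of SetParameterLayer's contract: get_weight() returns the difference target - original, so a patcher that computes patched = orig + get_weight(orig) lands exactly on the stored target value:

import torch

orig = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([0.5, 0.5, 0.5])

diff = target - orig  # what SetParameterLayer.get_weight(orig) returns
patched = orig + diff
assert torch.equal(patched, target)
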
@@ -9,6 +9,7 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
+from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:

invokeai/backend/lora/lora_layer_wrappers.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import torch

from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRASidecarWrapper(torch.nn.Module):
    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[AnyLoRALayer], lora_weights: list[float]):
        super().__init__()
        self._orig_module = orig_module
        self._lora_layers = lora_layers
        self._lora_weights = lora_weights

    @property
    def orig_module(self) -> torch.nn.Module:
        return self._orig_module

    def add_lora_layer(self, lora_layer: AnyLoRALayer, lora_weight: float):
        self._lora_layers.append(lora_layer)
        self._lora_weights.append(lora_weight)

    @torch.no_grad()
    def _get_lora_patched_parameters(
        self, orig_params: dict[str, torch.Tensor], lora_layers: list[AnyLoRALayer], lora_weights: list[float]
    ) -> dict[str, torch.Tensor]:
        params: dict[str, torch.Tensor] = {}
        for lora_layer, lora_weight in zip(lora_layers, lora_weights, strict=True):
            layer_params = lora_layer.get_parameters(self._orig_module)
            for param_name, param_weight in layer_params.items():
                if orig_params[param_name].shape != param_weight.shape:
                    param_weight = param_weight.reshape(orig_params[param_name].shape)

                if param_name not in params:
                    params[param_name] = param_weight * (lora_layer.scale() * lora_weight)
                else:
                    params[param_name] += param_weight * (lora_layer.scale() * lora_weight)

        return params


class LoRALinearWrapper(LoRASidecarWrapper):
    def _lora_linear_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear LoRALayer."""
        x = torch.nn.functional.linear(input, lora_layer.down)
        if lora_layer.mid is not None:
            x = torch.nn.functional.linear(x, lora_layer.mid)
        x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
        x *= lora_weight * lora_layer.scale()
        return x

    def _concatenated_lora_forward(
        self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
    ) -> torch.Tensor:
        """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
        x_chunks: list[torch.Tensor] = []
        for lora_layer in concatenated_lora_layer.lora_layers:
            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
            if lora_layer.mid is not None:
                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
            x_chunk *= lora_weight * lora_layer.scale()
            x_chunks.append(x_chunk)

        # TODO(ryand): Generalize to support concat_axis != 0.
        assert concatenated_lora_layer.concat_axis == 0
        x = torch.cat(x_chunks, dim=-1)
        return x

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Split the LoRA layers into those that have optimized implementations and those that don't.
        optimized_layer_types = (LoRALayer, ConcatenatedLoRALayer)
        optimized_layers = [
            (layer, weight)
            for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
            if isinstance(layer, optimized_layer_types)
        ]
        non_optimized_layers = [
            (layer, weight)
            for layer, weight in zip(self._lora_layers, self._lora_weights, strict=True)
            if not isinstance(layer, optimized_layer_types)
        ]

        # First, calculate the residual for LoRA layers for which there is an optimized implementation.
        residual = None
        for lora_layer, lora_weight in optimized_layers:
            if isinstance(lora_layer, LoRALayer):
                added_residual = self._lora_linear_forward(input, lora_layer, lora_weight)
            elif isinstance(lora_layer, ConcatenatedLoRALayer):
                added_residual = self._concatenated_lora_forward(input, lora_layer, lora_weight)
            else:
                raise ValueError(f"Unsupported LoRA layer type: {type(lora_layer)}")

            if residual is None:
                residual = added_residual
            else:
                residual += added_residual

        # Next, calculate the residuals for the LoRA layers for which there is no optimized implementation.
        if non_optimized_layers:
            unoptimized_layers, unoptimized_weights = zip(*non_optimized_layers, strict=True)
            params = self._get_lora_patched_parameters(
                orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
                lora_layers=unoptimized_layers,
                lora_weights=unoptimized_weights,
            )
            added_residual = torch.nn.functional.linear(input, params["weight"], params.get("bias", None))
            if residual is None:
                residual = added_residual
            else:
                residual += added_residual

        return self.orig_module(input) + residual


class LoRAConv1dWrapper(LoRASidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters(
            orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
            lora_layers=self._lora_layers,
            lora_weights=self._lora_weights,
        )
        return self.orig_module(input) + torch.nn.functional.conv1d(input, params["weight"], params.get("bias", None))


class LoRAConv2dWrapper(LoRASidecarWrapper):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        params = self._get_lora_patched_parameters(
            orig_params={"weight": self._orig_module.weight, "bias": self._orig_module.bias},
            lora_layers=self._lora_layers,
            lora_weights=self._lora_weights,
        )
        return self.orig_module(input) + torch.nn.functional.conv2d(input, params["weight"], params.get("bias", None))
||||
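Reviewer note: an illustrative sketch (not part of the diff) of how these wrappers behave. `ToyLoRALayer` is a hypothetical stand-in; because it is not a real `LoRALayer` subclass, it exercises the generic `_get_lora_patched_parameters` fallback path rather than the optimized one. It assumes only the `scale()` / `get_parameters()` interface used by the wrapper above.

import torch

from invokeai.backend.lora.lora_layer_wrappers import LoRALinearWrapper


class ToyLoRALayer:
    """Hypothetical LoRA layer exposing only the interface the wrapper relies on."""

    def __init__(self, in_features: int, out_features: int, rank: int):
        self.down = torch.randn(rank, in_features)
        self.up = torch.randn(out_features, rank)

    def scale(self) -> float:
        return 1.0

    def get_parameters(self, orig_module: torch.nn.Module) -> dict[str, torch.Tensor]:
        # Effective weight delta of a rank-decomposed layer: up @ down.
        return {"weight": self.up @ self.down}


base = torch.nn.Linear(16, 32)
layer = ToyLoRALayer(16, 32, rank=4)
wrapped = LoRALinearWrapper(base, [layer], [0.75])

x = torch.randn(2, 16)
# The wrapper adds the scaled low-rank residual to the frozen base output.
expected = base(x) + 0.75 * torch.nn.functional.linear(x, layer.up @ layer.down)
assert torch.allclose(wrapped(x), expected, atol=1e-5)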
@@ -4,19 +4,126 @@ from typing import Dict, Iterable, Optional, Tuple
 import torch

 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
-from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
-    ConcatenatedLoRALinearSidecarLayer,
+from invokeai.backend.lora.lora_layer_wrappers import (
+    LoRAConv1dWrapper,
+    LoRAConv2dWrapper,
+    LoRALinearWrapper,
+    LoRASidecarWrapper,
 )
-from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
-from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class LoRAPatcher:
+    @staticmethod
+    @torch.no_grad()
+    @contextmanager
+    def apply_smart_lora_patches(
+        model: torch.nn.Module,
+        patches: Iterable[Tuple[LoRAModelRaw, float]],
+        prefix: str,
+        dtype: torch.dtype,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """Apply 'smart' LoRA patching that chooses whether to use direct patching or a sidecar wrapper for each module."""
+        # original_weights are stored for unpatching layers that are directly patched.
+        original_weights = OriginalWeightsStorage(cached_weights)
+        # original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper.
+        original_modules: dict[str, torch.nn.Module] = {}
+        try:
+            for patch, patch_weight in patches:
+                LoRAPatcher._apply_smart_lora_patch(
+                    model=model,
+                    prefix=prefix,
+                    patch=patch,
+                    patch_weight=patch_weight,
+                    original_weights=original_weights,
+                    original_modules=original_modules,
+                    dtype=dtype,
+                )
+
+            yield
+        finally:
+            # Restore directly patched layers.
+            for param_key, weight in original_weights.get_changed_weights():
+                model.get_parameter(param_key).copy_(weight)
+
+            # Restore LoRASidecarWrapper modules.
+            # Note: This logic assumes no nested modules in original_modules.
+            for module_key, orig_module in original_modules.items():
+                module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
+                parent_module = model.get_submodule(module_parent_key)
+                LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
+
+    @staticmethod
+    @torch.no_grad()
+    def _apply_smart_lora_patch(
+        model: torch.nn.Module,
+        prefix: str,
+        patch: LoRAModelRaw,
+        patch_weight: float,
+        original_weights: OriginalWeightsStorage,
+        original_modules: dict[str, torch.nn.Module],
+        dtype: torch.dtype,
+    ):
+        """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
+        patching or a sidecar wrapper for each module.
+        """
+        if patch_weight == 0:
+            return
+
+        # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model
+        # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been
+        # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly
+        # without searching, but some legacy code still uses flattened keys.
+        layer_keys_are_flattened = "." not in next(iter(patch.layers.keys()))
+
+        prefix_len = len(prefix)
+
+        for layer_key, layer in patch.layers.items():
+            if not layer_key.startswith(prefix):
+                continue
+
+            module_key, module = LoRAPatcher._get_submodule(
+                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
+            )
+
+            # Decide whether to use direct patching or a sidecar wrapper.
+            # Direct patching is preferred, because it results in better runtime speed.
+            # Reasons to use sidecar patching:
+            # - The module is already wrapped in a LoRASidecarWrapper.
+            # - The module is quantized.
+            # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the
+            #   CPU, since this would double the RAM usage).
+            # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller
+            # and that the caller will use the 'apply_lora_wrapper_patches' method if the layer is quantized.
+            # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows
+            # forcing full patching even on the CPU?
+            if isinstance(module, LoRASidecarWrapper) or LoRAPatcher._is_any_part_of_layer_on_cpu(module):
+                LoRAPatcher._apply_lora_layer_wrapper_patch(
+                    model=model,
+                    module_to_patch=module,
+                    module_to_patch_key=module_key,
+                    patch=layer,
+                    patch_weight=patch_weight,
+                    original_modules=original_modules,
+                    dtype=dtype,
+                )
+            else:
+                LoRAPatcher._apply_lora_layer_patch(
+                    module_to_patch=module,
+                    module_to_patch_key=module_key,
+                    patch=layer,
+                    patch_weight=patch_weight,
+                    original_weights=original_weights,
+                )
+
+    @staticmethod
+    def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool:
+        return any(p.device.type == "cpu" for p in layer.parameters())
+
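Reviewer note: a minimal usage sketch of the new 'smart' patching entry point. It assumes `LoRAPatcher` lives at `invokeai.backend.lora.lora_patcher` and that a `LoRAModelRaw` with keys under the given prefix has already been loaded; both the module path and the prefix are assumptions for illustration.

import torch

from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher


def denoise_with_lora(unet: torch.nn.Module, lora: LoRAModelRaw, latents: torch.Tensor) -> torch.Tensor:
    # Patches are applied on entry and fully reverted on exit, whether a layer
    # was patched in place (CUDA-resident) or wrapped (CPU-resident/quantized).
    with LoRAPatcher.apply_smart_lora_patches(
        model=unet,
        patches=[(lora, 0.8)],  # (patch, patch_weight) pairs
        prefix="lora_unet_",
        dtype=torch.bfloat16,
    ):
        return unet(latents)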
     @staticmethod
     @torch.no_grad()
     @contextmanager
@@ -40,7 +147,7 @@ class LoRAPatcher:
         original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for patch, patch_weight in patches:
-                LoRAPatcher.apply_lora_patch(
+                LoRAPatcher._apply_lora_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -52,11 +159,12 @@ class LoRAPatcher:
             yield
         finally:
             for param_key, weight in original_weights.get_changed_weights():
-                model.get_parameter(param_key).copy_(weight)
+                cur_param = model.get_parameter(param_key)
+                cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)

     @staticmethod
     @torch.no_grad()
-    def apply_lora_patch(
+    def _apply_lora_patch(
         model: torch.nn.Module,
         prefix: str,
         patch: LoRAModelRaw,
@@ -91,48 +199,84 @@ class LoRAPatcher:
             module_key, module = LoRAPatcher._get_submodule(
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )

-            # All of the LoRA weight calculations will be done on the same device as the module weight.
-            # (Performance will be best if this is a CUDA device.)
-            device = module.weight.device
-            dtype = module.weight.dtype
+            LoRAPatcher._apply_lora_layer_patch(
+                module_to_patch=module,
+                module_to_patch_key=module_key,
+                patch=layer,
+                patch_weight=patch_weight,
+                original_weights=original_weights,
+            )

-            layer_scale = layer.scale()
+    @staticmethod
+    @torch.no_grad()
+    def _apply_lora_layer_patch(
+        module_to_patch: torch.nn.Module,
+        module_to_patch_key: str,
+        patch: AnyLoRALayer,
+        patch_weight: float,
+        original_weights: OriginalWeightsStorage,
+    ):
+        # All of the LoRA weight calculations will be done on the same device as the module weight.
+        # (Performance will be best if this is a CUDA device.)
+        first_param = next(module_to_patch.parameters())
+        device = first_param.device
+        dtype = first_param.dtype

-            # We intentionally move to the target device first, then cast. Experimentally, this was found to
-            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-            # same thing in a single call to '.to(...)'.
-            layer.to(device=device)
-            layer.to(dtype=torch.float32)
+        layer_scale = patch.scale()

-            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-            for param_name, lora_param_weight in layer.get_parameters(module).items():
-                param_key = module_key + "." + param_name
-                module_param = module.get_parameter(param_name)
+        # We intentionally move to the target device first, then cast. Experimentally, this was found to
+        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+        # same thing in a single call to '.to(...)'.
+        patch.to(device=device)
+        patch.to(dtype=torch.float32)

-                # Save original weight
-                original_weights.save(param_key, module_param)
+        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+        for param_name, lora_param_weight in patch.get_parameters(module_to_patch).items():
+            param_key = module_to_patch_key + "." + param_name
+            module_param = module_to_patch.get_parameter(param_name)

-                if module_param.shape != lora_param_weight.shape:
+            # Save original weight
+            original_weights.save(param_key, module_param)

+            if module_param.shape != lora_param_weight.shape:
+                if module_param.nelement() == lora_param_weight.nelement():
                     lora_param_weight = lora_param_weight.reshape(module_param.shape)
+                else:
+                    # This condition was added to handle layers in FLUX control LoRAs.
+                    # TODO(ryand): Move the weight update into the LoRA layer so that the LoRAPatcher doesn't need
+                    # to worry about this?
+                    expanded_weight = torch.zeros_like(
+                        lora_param_weight, dtype=module_param.dtype, device=module_param.device
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_param.shape)
+                    expanded_weight[slices] = module_param
+                    setattr(
+                        module_to_patch,
+                        param_name,
+                        torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
+                    )
+                    module_param = expanded_weight

-                lora_param_weight *= patch_weight * layer_scale
-                module_param += lora_param_weight.to(dtype=dtype)
+            lora_param_weight *= patch_weight * layer_scale
+            module_param += lora_param_weight.to(dtype=dtype)

-            layer.to(device=TorchDevice.CPU_DEVICE)
+        patch.to(device=TorchDevice.CPU_DEVICE)
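Reviewer note: the zero-padding branch above is easiest to see with concrete shapes. The sketch below uses hypothetical FLUX `img_in` dimensions (the probe later in this diff checks for a 128-wide `lora_A` and a 3072-tall `lora_B`).

import torch

# Base weight of a hypothetical img_in Linear: (3072, 64).
module_param = torch.randn(3072, 64)
# FLUX control LoRA delta with extra input channels: (3072, 128).
lora_param_weight = torch.randn(3072, 128)

# Embed the original weight in the top-left corner of a zero tensor of the
# larger shape, exactly as _apply_lora_layer_patch does above.
expanded_weight = torch.zeros_like(lora_param_weight)
slices = tuple(slice(0, dim) for dim in module_param.shape)
expanded_weight[slices] = module_param

assert torch.equal(expanded_weight[:, :64], module_param)  # original weights kept
assert torch.all(expanded_weight[:, 64:] == 0)  # new input channels start at zero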
     @staticmethod
     @torch.no_grad()
     @contextmanager
-    def apply_lora_sidecar_patches(
+    def apply_lora_wrapper_patches(
         model: torch.nn.Module,
         patches: Iterable[Tuple[LoRAModelRaw, float]],
         prefix: str,
         dtype: torch.dtype,
     ):
-        """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some
-        overhead compared to normal LoRA patching, but they allow LoRA layers to be applied to base layers in any
-        quantization format.
+        """Apply one or more LoRA wrapper patches to a model within a context manager. Wrapper patches incur some
+        runtime overhead compared to normal LoRA patching, but they enable:
+        - LoRA layers to be applied to quantized models.
+        - LoRA layers to be applied to CPU layers without needing to store a full copy of the original weights (i.e.
+          avoid doubling the memory requirements).

         Args:
             model (torch.nn.Module): The model to patch.
@@ -140,14 +284,11 @@ class LoRAPatcher:
                 associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory
                 all at once.
             prefix (str): The keys in the patches will be filtered to only include weights with this prefix.
             dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model,
                 since the sidecar layers are typically applied on top of quantized layers whose weight dtype is
                 different from their compute dtype.
         """
         original_modules: dict[str, torch.nn.Module] = {}
         try:
             for patch, patch_weight in patches:
-                LoRAPatcher._apply_lora_sidecar_patch(
+                LoRAPatcher._apply_lora_wrapper_patch(
                     model=model,
                     prefix=prefix,
                     patch=patch,
@@ -165,7 +306,7 @@ class LoRAPatcher:
             LoRAPatcher._set_submodule(parent_module, module_name, orig_module)

     @staticmethod
-    def _apply_lora_sidecar_patch(
+    def _apply_lora_wrapper_patch(
         model: torch.nn.Module,
         patch: LoRAModelRaw,
         patch_weight: float,
@@ -173,7 +314,7 @@ class LoRAPatcher:
         original_modules: dict[str, torch.nn.Module],
         dtype: torch.dtype,
     ):
-        """Apply a single LoRA sidecar patch to a model."""
+        """Apply a single LoRA wrapper patch to a model."""

         if patch_weight == 0:
             return
@@ -194,28 +335,47 @@ class LoRAPatcher:
                 model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
             )

-            # Initialize the LoRA sidecar layer.
-            lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
+            LoRAPatcher._apply_lora_layer_wrapper_patch(
+                model=model,
+                module_to_patch=module,
+                module_to_patch_key=module_key,
+                patch=layer,
+                patch_weight=patch_weight,
+                original_modules=original_modules,
+                dtype=dtype,
+            )

-            # Replace the original module with a LoRASidecarModule if it has not already been done.
-            if module_key in original_modules:
-                # The module has already been patched with a LoRASidecarModule. Append to it.
-                assert isinstance(module, LoRASidecarModule)
-                lora_sidecar_module = module
-            else:
-                # The module has not yet been patched with a LoRASidecarModule. Create one.
-                lora_sidecar_module = LoRASidecarModule(module, [])
-                original_modules[module_key] = module
-                module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
-                module_parent = model.get_submodule(module_parent_key)
-                LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
+    @staticmethod
+    @torch.no_grad()
+    def _apply_lora_layer_wrapper_patch(
+        model: torch.nn.Module,
+        module_to_patch: torch.nn.Module,
+        module_to_patch_key: str,
+        patch: AnyLoRALayer,
+        patch_weight: float,
+        original_modules: dict[str, torch.nn.Module],
+        dtype: torch.dtype,
+    ):
+        """Apply a single LoRA wrapper patch to a module."""

-            # Move the LoRA sidecar layer to the same device/dtype as the orig module.
-            # TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
-            lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
+        # Replace the original module with a LoRASidecarWrapper if it has not already been done.
+        if not isinstance(module_to_patch, LoRASidecarWrapper):
+            lora_wrapper_layer = LoRAPatcher._initialize_lora_wrapper_layer(module_to_patch)
+            original_modules[module_to_patch_key] = module_to_patch
+            module_parent_key, module_name = LoRAPatcher._split_parent_key(module_to_patch_key)
+            module_parent = model.get_submodule(module_parent_key)
+            LoRAPatcher._set_submodule(module_parent, module_name, lora_wrapper_layer)
+            orig_module = module_to_patch
+        else:
+            assert module_to_patch_key in original_modules
+            lora_wrapper_layer = module_to_patch
+            orig_module = module_to_patch.orig_module

-            # Add the LoRA sidecar layer to the LoRASidecarModule.
-            lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
+        # Move the LoRA layer to the same device/dtype as the orig module.
+        patch.to(device=orig_module.weight.device, dtype=dtype)

+        # Add the LoRA wrapper layer to the LoRASidecarWrapper.
+        lora_wrapper_layer.add_lora_layer(patch, patch_weight)

     @staticmethod
     def _split_parent_key(module_key: str) -> tuple[str, str]:
@@ -236,17 +396,13 @@ class LoRAPatcher:
             raise ValueError(f"Invalid module key: {module_key}")

     @staticmethod
-    def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
-        # TODO(ryand): Add support for more original layer types and LoRA layer types.
-        if isinstance(orig_layer, torch.nn.Linear) or (
-            isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
-        ):
-            if isinstance(lora_layer, LoRALayer):
-                return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
-            elif isinstance(lora_layer, ConcatenatedLoRALayer):
-                return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
-            else:
-                raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
+    def _initialize_lora_wrapper_layer(orig_layer: torch.nn.Module):
+        if isinstance(orig_layer, torch.nn.Linear):
+            return LoRALinearWrapper(orig_layer, [], [])
+        elif isinstance(orig_layer, torch.nn.Conv1d):
+            return LoRAConv1dWrapper(orig_layer, [], [])
+        elif isinstance(orig_layer, torch.nn.Conv2d):
+            return LoRAConv2dWrapper(orig_layer, [], [])
+        else:
+            raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
invokeai/backend/lora/sidecar_layers/concatenated_lora/concatenated_lora_linear_sidecar_layer.py  (deleted file)
@@ -1,34 +0,0 @@
import torch

from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer


class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
    def __init__(
        self,
        concatenated_lora_layer: ConcatenatedLoRALayer,
        weight: float,
    ):
        super().__init__()

        self._concatenated_lora_layer = concatenated_lora_layer
        self._weight = weight

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x_chunks: list[torch.Tensor] = []
        for lora_layer in self._concatenated_lora_layer.lora_layers:
            x_chunk = torch.nn.functional.linear(input, lora_layer.down)
            if lora_layer.mid is not None:
                x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
            x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
            x_chunk *= self._weight * lora_layer.scale()
            x_chunks.append(x_chunk)

        # TODO(ryand): Generalize to support concat_axis != 0.
        assert self._concatenated_lora_layer.concat_axis == 0
        x = torch.cat(x_chunks, dim=-1)
        return x

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self._concatenated_lora_layer.to(device=device, dtype=dtype)
        return self
invokeai/backend/lora/sidecar_layers/lora/lora_linear_sidecar_layer.py  (deleted file)
@@ -1,27 +0,0 @@
import torch

from invokeai.backend.lora.layers.lora_layer import LoRALayer


class LoRALinearSidecarLayer(torch.nn.Module):
    def __init__(
        self,
        lora_layer: LoRALayer,
        weight: float,
    ):
        super().__init__()

        self._lora_layer = lora_layer
        self._weight = weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.linear(x, self._lora_layer.down)
        if self._lora_layer.mid is not None:
            x = torch.nn.functional.linear(x, self._lora_layer.mid)
        x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
        x *= self._weight * self._lora_layer.scale()
        return x

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self._lora_layer.to(device=device, dtype=dtype)
        return self
invokeai/backend/lora/sidecar_layers/lora_sidecar_module.py  (deleted file)
@@ -1,24 +0,0 @@
import torch


class LoRASidecarModule(torch.nn.Module):
    """A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""

    def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
        super().__init__()
        self.orig_module = orig_module
        self._lora_layers = lora_layers

    def add_lora_layer(self, lora_layer: torch.nn.Module):
        self._lora_layers.append(lora_layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.orig_module(input)
        for lora_layer in self._lora_layers:
            x += lora_layer(input)
        return x

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self._orig_module.to(device=device, dtype=dtype)
        for lora_layer in self._lora_layers:
            lora_layer.to(device=device, dtype=dtype)
@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
     Any = "any"
     StableDiffusion1 = "sd-1"
     StableDiffusion2 = "sd-2"
+    StableDiffusion3 = "sd-3"
     StableDiffusionXL = "sdxl"
     StableDiffusionXLRefiner = "sdxl-refiner"
     Flux = "flux"
@@ -66,6 +67,7 @@ class ModelType(str, Enum):
     Main = "main"
     VAE = "vae"
     LoRA = "lora"
+    StructuralLoRa = "structural_lora"
     ControlNet = "controlnet"  # used by model_probe
     TextualInversion = "embedding"
     IPAdapter = "ip_adapter"
@@ -83,8 +85,10 @@ class SubModelType(str, Enum):
     Transformer = "transformer"
     TextEncoder = "text_encoder"
     TextEncoder2 = "text_encoder_2"
+    TextEncoder3 = "text_encoder_3"
     Tokenizer = "tokenizer"
     Tokenizer2 = "tokenizer_2"
+    Tokenizer3 = "tokenizer_3"
     VAE = "vae"
     VAEDecoder = "vae_decoder"
     VAEEncoder = "vae_encoder"
@@ -92,6 +96,13 @@ class SubModelType(str, Enum):
     SafetyChecker = "safety_checker"


+class ClipVariantType(str, Enum):
+    """Variant type."""
+
+    L = "large"
+    G = "gigantic"
+
+
 class ModelVariantType(str, Enum):
     """Variant type."""

@@ -147,6 +158,17 @@ class ModelSourceType(str, Enum):
 DEFAULTS_PRECISION = Literal["fp16", "fp32"]


+AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
+
+
+class SubmodelDefinition(BaseModel):
+    path_or_prefix: str
+    model_type: ModelType
+    variant: AnyVariant = None
+
+    model_config = ConfigDict(protected_namespaces=())
+
+
 class MainModelDefaultSettings(BaseModel):
     vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
     vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +215,9 @@ class ModelConfigBase(BaseModel):
             schema["required"].extend(["key", "type", "format"])

     model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
+    submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
+        description="Loadable submodels in this model", default=None
+    )


 class CheckpointConfigBase(ModelConfigBase):
@@ -249,6 +274,18 @@ class LoRALyCORISConfig(LoRAConfigBase):
         return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")


+class StructuralLoRALyCORISConfig(ModelConfigBase):
+    """Model config for Structural LoRA/Lycoris models."""
+
+    type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
+    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
+    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
+
+
 class LoRADiffusersConfig(LoRAConfigBase):
     """Model config for LoRA/Diffusers models."""

@@ -335,7 +372,7 @@ class MainConfigBase(ModelConfigBase):
     default_settings: Optional[MainModelDefaultSettings] = Field(
         description="Default settings for this model", default=None
     )
-    variant: ModelVariantType = ModelVariantType.Normal
+    variant: AnyVariant = ModelVariantType.Normal


 class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@@ -419,12 +456,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):

     type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+    variant: ClipVariantType = ClipVariantType.L

     @staticmethod
     def get_tag() -> Tag:
         return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")


+class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
+    """Model config for CLIP-G Embeddings."""
+
+    variant: ClipVariantType = ClipVariantType.G
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
+
+
+class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
+    """Model config for CLIP-L Embeddings."""
+
+    variant: ClipVariantType = ClipVariantType.L
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""

@@ -490,6 +548,7 @@ AnyModelConfig = Annotated[
         Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
         Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
         Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
+        Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
         Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
@@ -501,6 +560,8 @@ AnyModelConfig = Annotated[
         Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
         Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
+        Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
+        Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
 ]
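Reviewer note: a hypothetical example of what the new `submodels` field might hold for an SD3 main model. The paths are illustrative, not from the diff; the field names come from the `SubmodelDefinition` added above.

from invokeai.backend.model_manager.config import (
    ClipVariantType,
    ModelType,
    SubmodelDefinition,
    SubModelType,
)

submodels = {
    SubModelType.Transformer: SubmodelDefinition(
        path_or_prefix="/models/sd3.5-medium/transformer",  # illustrative path
        model_type=ModelType.Main,
    ),
    SubModelType.TextEncoder: SubmodelDefinition(
        path_or_prefix="/models/sd3.5-medium/text_encoder",
        model_type=ModelType.CLIPEmbed,
        variant=ClipVariantType.L,  # CLIP-L: hidden_size 768
    ),
}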
@@ -35,6 +35,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._torch_dtype = TorchDevice.choose_torch_dtype()
+        self._torch_device = TorchDevice.choose_torch_device()

     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
             model = AutoEncoder(ae_params[config.config_path])
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
-            model.to(dtype=self._torch_dtype)
+            # VAE is broken in float16, which mps defaults to
+            if self._torch_dtype == torch.float16:
+                try:
+                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+                except TypeError:
+                    vae_dtype = torch.float32
+            else:
+                vae_dtype = self._torch_dtype
+            model.to(vae_dtype)

             return model

@@ -128,9 +136,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
                 "The bnb modules are not available. Please install bitsandbytes if available on your platform."
             )
         match submodel_type:
-            case SubModelType.Tokenizer2:
+            case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
-            case SubModelType.TextEncoder2:
+            case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 te2_model_path = Path(config.path) / "text_encoder_2"
                 model_config = AutoConfig.from_pretrained(te2_model_path)
                 with accelerate.init_empty_weights():
@@ -172,9 +180,9 @@ class T5EncoderCheckpointModel(ModelLoader):
             raise ValueError("Only T5EncoderConfig models are currently supported here.")

         match submodel_type:
-            case SubModelType.Tokenizer2:
+            case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
-            case SubModelType.TextEncoder2:
+            case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")

         raise ValueError(
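Reviewer note: a small sketch of the dtype probe used in the VAE loader above. Creating a tiny bfloat16 tensor on the target device raises TypeError on backends without bfloat16 support (e.g. older MPS builds), in which case the VAE falls back to float32. The CPU device below is a stand-in for `self._torch_device`.

import torch

device = torch.device("cpu")  # stand-in for self._torch_device
try:
    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
except TypeError:
    vae_dtype = torch.float32
print(vae_dtype)  # torch.bfloat16 if the backend supports it, else torch.float32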
@@ -13,8 +13,9 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
     lora_model_from_flux_diffusers_state_dict,
 )
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    lora_model_from_flux_kohya_state_dict,
+    is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
 from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
 from invokeai.backend.model_manager import (
@@ -32,6 +33,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade

 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
 class LoRALoader(ModelLoader):
     """Class to load LoRA models."""

@@ -75,7 +77,10 @@ class LoRALoader(ModelLoader):
             # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
             model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
         elif config.format == ModelFormat.LyCORIS:
-            model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
+            if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
+                model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
+            elif is_state_dict_likely_flux_control(state_dict=state_dict):
+                model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
         else:
             raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
         elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 @ModelLoaderRegistry.register(
     base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
 )
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     """Class to load main models."""

-    model_base_to_model_type = {
-        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
-        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
-        BaseModelType.StableDiffusionXL: "SDXL",
-        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
-    }
-
     def _load_model(
         self,
         config: AnyModelConfig,
@@ -117,8 +111,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             load_class = load_classes[config.base][config.variant]
         except KeyError as e:
             raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
-        prediction_type = config.prediction_type.value
-        upcast_attention = config.upcast_attention

         # Without SilenceWarnings we get log messages like this:
         # site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
@@ -129,13 +121,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
         # ['text_model.embeddings.position_ids']

         with SilenceWarnings():
-            pipeline = load_class.from_single_file(
-                config.path,
-                torch_dtype=self._torch_dtype,
-                prediction_type=prediction_type,
-                upcast_attention=upcast_attention,
-                load_safety_checker=False,
-            )
+            pipeline = load_class.from_single_file(config.path, torch_dtype=self._torch_dtype)

         if not submodel_type:
             return pipeline
@@ -20,7 +20,7 @@ from typing import Optional

 import requests
 from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
-from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
+from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
@@ -1,7 +1,7 @@
 import json
 import re
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union

 import safetensors.torch
 import spandrel
@@ -18,10 +18,12 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlab
 from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
 )
+from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
+    AnyVariant,
     BaseModelType,
     ControlAdapterDefaultSettings,
     InvalidModelConfigException,
@@ -33,8 +35,15 @@ from invokeai.backend.model_manager.config import (
     ModelType,
     ModelVariantType,
     SchedulerPredictionType,
+    SubmodelDefinition,
     SubModelType,
 )
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
+from invokeai.backend.model_manager.util.model_util import (
+    get_clip_variant_type,
+    lora_token_vector_length,
+    read_checkpoint_meta,
+)
-from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -112,6 +121,7 @@ class ModelProbe(object):
         "StableDiffusionXLPipeline": ModelType.Main,
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
+        "StableDiffusion3Pipeline": ModelType.Main,
         "LatentConsistencyModelPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.VAE,
         "AutoencoderTiny": ModelType.VAE,
@@ -122,8 +132,12 @@ class ModelProbe(object):
         "CLIPTextModel": ModelType.CLIPEmbed,
         "T5EncoderModel": ModelType.T5Encoder,
         "FluxControlNetModel": ModelType.ControlNet,
+        "SD3Transformer2DModel": ModelType.Main,
+        "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
     }

+    TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
+
     @classmethod
     def register_probe(
         cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -170,7 +184,10 @@ class ModelProbe(object):
         fields["path"] = model_path.as_posix()
         fields["type"] = fields.get("type") or model_type
         fields["base"] = fields.get("base") or probe.get_base_type()
-        fields["variant"] = fields.get("variant") or probe.get_variant_type()
+        variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
+        fields["variant"] = (
+            fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
+        )
         fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
         fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
         fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -217,6 +234,10 @@ class ModelProbe(object):
             and fields["prediction_type"] == SchedulerPredictionType.VPrediction
         )

+        get_submodels = getattr(probe, "get_submodels", None)
+        if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
+            fields["submodels"] = get_submodels()
+
         model_info = ModelConfigFactory.make_config(fields)  # , key=fields.get("key", None))
         return model_info

@@ -238,6 +259,18 @@ class ModelProbe(object):
         ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
         ckpt = ckpt.get("state_dict", ckpt)

+        if isinstance(ckpt, dict) and "img_in.lora_A.weight" in ckpt and "img_in.lora_B.weight" in ckpt:
+            tensor_a, tensor_b = ckpt["img_in.lora_A.weight"], ckpt["img_in.lora_B.weight"]
+            if (
+                tensor_a is not None
+                and isinstance(tensor_a, torch.Tensor)
+                and tensor_a.shape[1] == 128
+                and tensor_b is not None
+                and isinstance(tensor_b, torch.Tensor)
+                and tensor_b.shape[0] == 3072
+            ):
+                return ModelType.StructuralLoRa
+
         for key in [str(k) for k in ckpt.keys()]:
             if key.startswith(
                 (
@@ -449,7 +482,7 @@ class ModelProbe(object):
         """
         # scan model
         scan_result = scan_file_path(checkpoint)
-        if scan_result.infected_files != 0:
+        if scan_result.infected_files != 0 or scan_result.scan_err:
             raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")


@@ -465,6 +498,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
     "lineart anime": "lineart_anime_image_processor",
     "lineart_anime": "lineart_anime_image_processor",
     "lineart": "lineart_image_processor",
+    "soft": "hed_image_processor",
     "softedge": "hed_image_processor",
     "hed": "hed_image_processor",
     "shuffle": "content_shuffle_image_processor",
@@ -603,8 +637,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
         return ModelFormat.LyCORIS

     def get_base_type(self) -> BaseModelType:
-        if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
-            self.checkpoint
+        if (
+            is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
+            or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
+            or is_state_dict_likely_flux_control(self.checkpoint)
         ):
             return BaseModelType.Flux

@@ -747,18 +783,33 @@ class FolderProbeBase(ProbeBase):

 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
-        with open(self.model_path / "unet" / "config.json", "r") as file:
-            unet_conf = json.load(file)
-        if unet_conf["cross_attention_dim"] == 768:
-            return BaseModelType.StableDiffusion1
-        elif unet_conf["cross_attention_dim"] == 1024:
-            return BaseModelType.StableDiffusion2
-        elif unet_conf["cross_attention_dim"] == 1280:
-            return BaseModelType.StableDiffusionXLRefiner
-        elif unet_conf["cross_attention_dim"] == 2048:
-            return BaseModelType.StableDiffusionXL
-        else:
-            raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
+        # Handle pipelines with a UNet (i.e. SD 1.x, SD2, SDXL).
+        config_path = self.model_path / "unet" / "config.json"
+        if config_path.exists():
+            with open(config_path) as file:
+                unet_conf = json.load(file)
+            if unet_conf["cross_attention_dim"] == 768:
+                return BaseModelType.StableDiffusion1
+            elif unet_conf["cross_attention_dim"] == 1024:
+                return BaseModelType.StableDiffusion2
+            elif unet_conf["cross_attention_dim"] == 1280:
+                return BaseModelType.StableDiffusionXLRefiner
+            elif unet_conf["cross_attention_dim"] == 2048:
+                return BaseModelType.StableDiffusionXL
+            else:
+                raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
+
+        # Handle pipelines with a transformer (i.e. SD3).
+        config_path = self.model_path / "transformer" / "config.json"
+        if config_path.exists():
+            with open(config_path) as file:
+                transformer_conf = json.load(file)
+            if transformer_conf["_class_name"] == "SD3Transformer2DModel":
+                return BaseModelType.StableDiffusion3
+            else:
+                raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
+
+        raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")

     def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
         with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -770,6 +821,23 @@ class PipelineFolderProbe(FolderProbeBase):
         else:
             raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")

+    def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
+        config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
+        submodels: Dict[SubModelType, SubmodelDefinition] = {}
+        for key, value in config.items():
+            if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
+                continue
+            model_loader = str(value[1])
+            if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
+                variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
+                submodels[SubModelType(key)] = SubmodelDefinition(
+                    path_or_prefix=(self.model_path / key).resolve().as_posix(),
+                    model_type=model_type,
+                    variant=variant_func and variant_func((self.model_path / key).as_posix()),
+                )
+
+        return submodels
+
     def get_variant_type(self) -> ModelVariantType:
         # This only works for pipelines! Any kind of
         # exception results in our returning the
@@ -993,6 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
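Reviewer note: a hypothetical sketch of what the SD3 `get_submodels()` probe above derives from a diffusers `model_index.json`. The dictionary below is illustrative, not a real index file.

model_index = {
    "_class_name": "StableDiffusion3Pipeline",  # "_"-prefixed keys are skipped
    "text_encoder": ["transformers", "CLIPTextModelWithProjection"],
    "transformer": ["diffusers", "SD3Transformer2DModel"],
    "vae": ["diffusers", "AutoencoderKL"],
}

for key, value in model_index.items():
    if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
        continue
    # value[1] is the class name; CLASS2TYPE maps it to a ModelType, e.g.
    # "SD3Transformer2DModel" -> ModelType.Main, "AutoencoderKL" -> ModelType.VAE,
    # and a SubmodelDefinition pointing at "<model_path>/<key>" is recorded.
    print(key, "->", value[1])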
@@ -140,6 +140,22 @@ flux_dev = StarterModel(
     type=ModelType.Main,
     dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
 )
+sd35_medium = StarterModel(
+    name="SD3.5 Medium",
+    base=BaseModelType.StableDiffusion3,
+    source="stabilityai/stable-diffusion-3.5-medium",
+    description="Medium SD3.5 Model: ~15GB",
+    type=ModelType.Main,
+    dependencies=[],
+)
+sd35_large = StarterModel(
+    name="SD3.5 Large",
+    base=BaseModelType.StableDiffusion3,
+    source="stabilityai/stable-diffusion-3.5-large",
+    description="Large SD3.5 Model: ~19GB",
+    type=ModelType.Main,
+    dependencies=[],
+)
 cyberrealistic_sd1 = StarterModel(
     name="CyberRealistic v4.1",
     base=BaseModelType.StableDiffusion1,
@@ -282,13 +298,12 @@ ip_adapter_sdxl = StarterModel(
     previous_names=["IP Adapter SDXL"],
 )
 ip_adapter_flux = StarterModel(
-    name="Standard Reference (XLabs FLUX IP-Adapter)",
+    name="Standard Reference (XLabs FLUX IP-Adapter v2)",
     base=BaseModelType.Flux,
-    source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
+    source="https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/resolve/main/ip_adapter.safetensors",
     description="References images with a more generalized/looser degree of precision.",
     type=ModelType.IPAdapter,
     dependencies=[clip_vit_l_image_encoder],
+    previous_names=["XLabs FLUX IP-Adapter"],
 )
 # endregion
 # region ControlNet
@@ -570,6 +585,8 @@ STARTER_MODELS: list[StarterModel] = [
     flux_dev_quantized,
     flux_schnell,
     flux_dev,
+    sd35_medium,
+    sd35_large,
     cyberrealistic_sd1,
     rev_animated_sd1,
     dreamshaper_8_sd1,
@@ -8,6 +8,7 @@ import safetensors
 import torch
 from picklescan.scanner import scan_file_path

+from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader


@@ -43,7 +44,7 @@ def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
     return checkpoint


-def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
+def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
     if str(path).endswith(".safetensors"):
         try:
             path_str = path.as_posix() if isinstance(path, Path) else path
@@ -51,16 +52,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
         except Exception:
             # TODO: create issue for support "meta"?
             checkpoint = safetensors.torch.load_file(path, device="cpu")
+    elif str(path).endswith(".gguf"):
+        # The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function
+        checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
     else:
         if scan:
             scan_result = scan_file_path(path)
-            if scan_result.infected_files != 0:
+            if scan_result.infected_files != 0 or scan_result.scan_err:
                 raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
-        if str(path).endswith(".gguf"):
-            # The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function
-            checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
-        else:
-            checkpoint = torch.load(path, map_location=torch.device("meta"))
+        checkpoint = torch.load(path, map_location=torch.device("meta"))
     return checkpoint

@@ -165,3 +165,25 @@ def convert_bundle_to_flux_transformer_checkpoint(
             del transformer_state_dict[k]

     return original_state_dict
+
+
+def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
+    try:
+        path = Path(location)
+        config_path = path / "config.json"
+        if not config_path.exists():
+            config_path = path / "text_encoder" / "config.json"
+        if not config_path.exists():
+            return ClipVariantType.L
+        with open(config_path) as file:
+            clip_conf = json.load(file)
+        hidden_size = clip_conf.get("hidden_size", -1)
+        match hidden_size:
+            case 1280:
+                return ClipVariantType.G
+            case 768:
+                return ClipVariantType.L
+            case _:
+                return ClipVariantType.L
+    except Exception:
+        return ClipVariantType.L
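Reviewer note: hypothetical usage of `get_clip_variant_type()` above. Given a local CLIP text encoder folder, the probe reads `hidden_size` from `config.json` to tell CLIP-L (768) from CLIP-G (1280); the path below is illustrative.

from invokeai.backend.model_manager.util.model_util import get_clip_variant_type

variant = get_clip_variant_type("/models/clip/clip-vit-large-patch14")  # illustrative path
print(variant)  # ClipVariantType.L for a 768-wide encoder; defaults to L on any error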
@@ -85,6 +85,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
     """Select the proper variant files from a list of HuggingFace repo_id paths."""
     result: set[Path] = set()
     subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
+    safetensors_detected = False
     for path in files:
         if path.suffix in [".onnx", ".pb", ".onnx_data"]:
             if variant == ModelRepoVariant.ONNX:
@@ -119,19 +120,27 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
         # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
         # variant and format and select the best one.

+        if safetensors_detected and path.suffix == ".bin":
+            continue
+
         parent = path.parent
         score = 0

         if path.suffix == ".safetensors":
+            safetensors_detected = True
+            if parent in subfolder_weights:
+                subfolder_weights[parent] = [sfc for sfc in subfolder_weights[parent] if sfc.path.suffix != ".bin"]
             score += 1

         candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None

         # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
         # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
-        if candidate_variant_label == f".{variant}" or (
-            not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
-        ):
+        if (
+            variant is not ModelRepoVariant.Default
+            and candidate_variant_label
+            and candidate_variant_label.startswith(f".{variant.value}")
+        ) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
             score += 1

         if parent not in subfolder_weights:
@@ -146,7 +155,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
         # Check if at least one of the files has the explicit fp16 variant.
         at_least_one_fp16 = False
         for candidate in candidate_list:
-            if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
+            if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
                 at_least_one_fp16 = True
                 break

@@ -162,7 +171,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
         # candidate.
         highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
         if highest_score_candidate:
-            result.add(highest_score_candidate.path)
+            pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
+            match = re.match(pattern, highest_score_candidate.path.as_posix())
+            if match:
+                for candidate in candidate_list:
+                    if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
+                        match.group(2)
+                    ):
+                        result.add(candidate.path)
+            else:
+                result.add(highest_score_candidate.path)

     # If one of the architecture-related variants was specified and no files matched other than
     # config and text files then we return an empty list
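Reviewer note: an illustrative sketch of the sharded-checkpoint handling added above. When the best-scoring file is one shard of an N-part checkpoint, its sibling shards are selected too. The file names below are hypothetical.

import re

pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
files = [
    "transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
    "transformer/diffusion_pytorch_model-00002-of-00002.safetensors",
]
match = re.match(pattern, files[0])
assert match is not None
# group(1) is the shard prefix, group(2) the extension; both shards match.
siblings = [f for f in files if f.startswith(match.group(1)) and f.endswith(match.group(2))]
print(siblings)  # both shards end up in the result set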
58  invokeai/backend/sd3/extensions/inpaint_extension.py  Normal file
@@ -0,0 +1,58 @@
import torch


class InpaintExtension:
    """A class for managing inpainting with SD3."""

    def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
        """Initialize InpaintExtension.

        Args:
            init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0).
            inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
                re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
                inpainted region with the background.
            noise (torch.Tensor): The noise tensor used to noise the init_latents.
        """
        assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4
        assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:]

        self._init_latents = init_latents
        self._inpaint_mask = inpaint_mask
        self._noise = noise

    def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
        """Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
        # As we progress through the denoising process, we promote gradient regions of the mask to have a full weight
        # of 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small)
        # number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff
        # threshold worked well.
        # We use a small epsilon to avoid any potential issues with floating point precision.
        eps = 1e-4
        mask_gradient_t_cutoff = 0.5
        if t_prev > mask_gradient_t_cutoff:
            # Early in the denoising process, use the inpaint mask as-is.
            return self._inpaint_mask
        else:
            # After the cut-off, promote all non-zero mask values to 1.0.
            mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0)

        return mask

    def merge_intermediate_latents_with_init_latents(
        self, intermediate_latents: torch.Tensor, t_prev: float
    ) -> torch.Tensor:
        """Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask.
        I.e. update the intermediate latents to keep the regions that are not being inpainted on the correct noise
        trajectory.

        This function should be called after each denoising step.
        """
        mask = self._apply_mask_gradient_adjustment(t_prev)

        # Noise the init latents for the current timestep.
        noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents

        # Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
        return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
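A minimal usage sketch of the extension in a rectified-flow style denoising loop. The tensor shapes, the mask region, and the timestep schedule are illustrative only; the actual caller wires this into the SD3 denoising pipeline:

import torch

from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension

# 4D tensors: (batch, channels, height, width).
init_latents = torch.randn(1, 16, 64, 64)
noise = torch.randn(1, 16, 64, 64)
inpaint_mask = torch.zeros(1, 1, 64, 64)
inpaint_mask[..., 16:48, 16:48] = 1.0  # re-generate only the central region

ext = InpaintExtension(init_latents=init_latents, inpaint_mask=inpaint_mask, noise=noise)

latents = noise.clone()
for t_prev in [0.75, 0.5, 0.25, 0.0]:  # hypothetical flow-matching schedule, t runs from 1 to 0
    # ... one denoising step producing new intermediate latents would go here ...
    # After each step, pull the un-inpainted regions back onto their noise trajectory.
    latents = ext.merge_intermediate_latents_with_init_latents(latents, t_prev)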
@@ -499,6 +499,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
                    accum_adapter_state[idx] += value * t2i_adapter_weight

            # Hack: force compatibility with irregular resolutions by padding the feature map with zeros.
            for idx, tensor in enumerate(accum_adapter_state):
                # The tensor size is supposed to be some integer downscale factor of the latents size.
                # Internally, the unet will pad the latents before downscaling between levels when the size is no
                # longer divisible by its downscale factor. If the latent size does not scale down evenly, we need
                # to pad the tensor so that it matches the downscaled padded latents later on.
                scale_factor = latents.size()[-1] // tensor.size()[-1]
                required_padding_width = math.ceil(latents.size()[-1] / scale_factor) - tensor.size()[-1]
                required_padding_height = math.ceil(latents.size()[-2] / scale_factor) - tensor.size()[-2]
                tensor = torch.nn.functional.pad(
                    tensor,
                    (0, required_padding_width, 0, required_padding_height, 0, 0, 0, 0),
                    mode="constant",
                    value=0,
                )
                accum_adapter_state[idx] = tensor

            down_intrablock_additional_residuals = accum_adapter_state

        # Handle inpainting models.
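To make the padding arithmetic concrete, here is a standalone run with a hypothetical irregular resolution (the tensor sizes are invented for illustration; only the arithmetic mirrors the hunk above):

import math

import torch

latents = torch.zeros(1, 4, 75, 75)  # latent size not evenly divisible by the downscale factor
tensor = torch.zeros(1, 320, 9, 9)   # adapter feature map at one UNet level

scale_factor = latents.size()[-1] // tensor.size()[-1]                                      # 75 // 9 = 8
required_padding_width = math.ceil(latents.size()[-1] / scale_factor) - tensor.size()[-1]   # 10 - 9 = 1
required_padding_height = math.ceil(latents.size()[-2] / scale_factor) - tensor.size()[-2]  # 10 - 9 = 1

# Pad the right edge and bottom edge of the feature map with zeros.
padded = torch.nn.functional.pad(
    tensor, (0, required_padding_width, 0, required_padding_height), mode="constant", value=0
)
print(padded.shape)  # torch.Size([1, 320, 10, 10]) -- now matches the padded, downscaled latents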
@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
        return self


@dataclass
class SD3ConditioningInfo:
    clip_l_pooled_embeds: torch.Tensor
    clip_l_embeds: torch.Tensor
    clip_g_pooled_embeds: torch.Tensor
    clip_g_embeds: torch.Tensor
    t5_embeds: torch.Tensor | None

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
        self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
        self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
        self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
        if self.t5_embeds is not None:
            self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class ConditioningFieldData:
    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
    conditionings: (
        List[BasicConditioningInfo]
        | List[SDXLConditioningInfo]
        | List[FLUXConditioningInfo]
        | List[SD3ConditioningInfo]
    )


@dataclass
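A quick sketch of the new dataclass's device/dtype plumbing. The embedding shapes below are placeholders, not necessarily the real CLIP/T5 dimensions; the point is that a single to() call moves every tensor and skips the optional T5 embeddings when they are absent:

import torch

cond = SD3ConditioningInfo(
    clip_l_pooled_embeds=torch.randn(1, 768),
    clip_l_embeds=torch.randn(1, 77, 768),
    clip_g_pooled_embeds=torch.randn(1, 1280),
    clip_g_embeds=torch.randn(1, 77, 1280),
    t5_embeds=None,  # T5 is optional; to() leaves a None untouched
)

# Move every embedding in one call.
cond = cond.to(device=torch.device("cpu"), dtype=torch.float16)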
@@ -33,7 +33,7 @@ class PreviewExt(ExtensionBase):
    def initial_preview(self, ctx: DenoiseContext):
        self.callback(
            PipelineIntermediateState(
                step=-1,
                step=0,
                order=ctx.scheduler.order,
                total_steps=len(ctx.inputs.timesteps),
                timestep=int(ctx.scheduler.config.num_train_timesteps),  # TODO: is there any code which uses it?