mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 01:17:56 -05:00
Compare commits
280 Commits
psychedeli
...
v5.5.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
948ecf9333 | ||
|
|
1038f7bcab | ||
|
|
c7d9e2d62a | ||
|
|
11c3a2e15d | ||
|
|
9e3ca383ec | ||
|
|
bda83c2634 | ||
|
|
525cb38c71 | ||
|
|
a9a6720bad | ||
|
|
858bf9cf8c | ||
|
|
74a29c3735 | ||
|
|
6fc6be3aa0 | ||
|
|
174ea021a6 | ||
|
|
50b804e087 | ||
|
|
23270d7dfe | ||
|
|
39e6f6d53f | ||
|
|
c154d833b9 | ||
|
|
899a00af62 | ||
|
|
7c9ecdb362 | ||
|
|
4a5255611b | ||
|
|
b5b39db304 | ||
|
|
2cb5743cc5 | ||
|
|
64ee8d491e | ||
|
|
d70d48de45 | ||
|
|
3f8636330f | ||
|
|
0c2f96daf1 | ||
|
|
c9b2cce627 | ||
|
|
401fb392b8 | ||
|
|
594511cf4a | ||
|
|
d764aa4a2a | ||
|
|
ea34726329 | ||
|
|
9b615e0de7 | ||
|
|
a463e97269 | ||
|
|
b272d46056 | ||
|
|
4d5f74c05b | ||
|
|
dd09509dbd | ||
|
|
7fad4c9491 | ||
|
|
b820862eab | ||
|
|
c604a0956e | ||
|
|
9369b39a12 | ||
|
|
80f64abd1e | ||
|
|
37e3089457 | ||
|
|
fe09f2d27a | ||
|
|
e7e3f7e144 | ||
|
|
606d58d7db | ||
|
|
c76a448846 | ||
|
|
46133b5656 | ||
|
|
ac28370fd2 | ||
|
|
1e0552c813 | ||
|
|
e2451ef5ca | ||
|
|
443d838fd0 | ||
|
|
3a8a5442ea | ||
|
|
808e3770d3 | ||
|
|
2b441d6a2d | ||
|
|
58de93a89e | ||
|
|
1eede4315e | ||
|
|
8ea697d733 | ||
|
|
693d42661c | ||
|
|
41664f88db | ||
|
|
42f8d6aa11 | ||
|
|
5f41a69665 | ||
|
|
7da90a9b6b | ||
|
|
440185cc40 | ||
|
|
26edc71268 | ||
|
|
a4bed7aee3 | ||
|
|
5fcd76a712 | ||
|
|
516ffa641c | ||
|
|
d84adfd39f | ||
|
|
ac82f73dbe | ||
|
|
70811d0bd0 | ||
|
|
e0344a302c | ||
|
|
92b0d89b70 | ||
|
|
da213e4638 | ||
|
|
246b59f148 | ||
|
|
046d19446c | ||
|
|
040551d4fb | ||
|
|
f53da60b84 | ||
|
|
5a035dd19f | ||
|
|
f3b253987f | ||
|
|
25ff7918e8 | ||
|
|
09fc60acb0 | ||
|
|
6f55f2c723 | ||
|
|
03b815c884 | ||
|
|
9cecdd17eb | ||
|
|
6b0f7ab57c | ||
|
|
c805e38da2 | ||
|
|
2c1de0f07d | ||
|
|
261d5ab488 | ||
|
|
ca571cd7a9 | ||
|
|
4c94d41fa9 | ||
|
|
4036244ee9 | ||
|
|
d06232d9ba | ||
|
|
bacbdfb8fc | ||
|
|
59f42f4682 | ||
|
|
a636ac2899 | ||
|
|
bd478360d9 | ||
|
|
ac0db07649 | ||
|
|
b7132ce9e7 | ||
|
|
90f30e7748 | ||
|
|
6b86a66bc7 | ||
|
|
aa97e626e9 | ||
|
|
c90736093f | ||
|
|
0bff4ace1b | ||
|
|
5eb382074e | ||
|
|
46aa930526 | ||
|
|
3305bad0c2 | ||
|
|
13703d8f55 | ||
|
|
60d838d0a5 | ||
|
|
2a157a44bf | ||
|
|
d61b5833c2 | ||
|
|
c094838c6a | ||
|
|
2d334c8dd8 | ||
|
|
a6be26e174 | ||
|
|
f8c7adddd0 | ||
|
|
17da1d92e9 | ||
|
|
1cc57a4854 | ||
|
|
3993fae331 | ||
|
|
1446526d55 | ||
|
|
62c024e725 | ||
|
|
1e92bb4e94 | ||
|
|
db6398fdf6 | ||
|
|
ebd73a2ac2 | ||
|
|
8ee95cab00 | ||
|
|
d1184201a8 | ||
|
|
5887891654 | ||
|
|
765ca4e004 | ||
|
|
159b00a490 | ||
|
|
3fbf6f2d2a | ||
|
|
931fca7cd1 | ||
|
|
db84a3a5d4 | ||
|
|
ca8313e805 | ||
|
|
df849035ee | ||
|
|
8d97fe69ca | ||
|
|
9044e53a9b | ||
|
|
6012b0f912 | ||
|
|
bb0ed5dc8a | ||
|
|
021552fd81 | ||
|
|
be73dbba92 | ||
|
|
db9c0cad7c | ||
|
|
54b7f9a063 | ||
|
|
7d488a5352 | ||
|
|
4d7667f63d | ||
|
|
08704ee8ec | ||
|
|
5910892c33 | ||
|
|
46a09d9e90 | ||
|
|
df0c7d73f3 | ||
|
|
3905c97e32 | ||
|
|
0be796a808 | ||
|
|
7dd33b0f39 | ||
|
|
484aaf1595 | ||
|
|
c276b60af9 | ||
|
|
5d8dd6e26e | ||
|
|
5bca68d873 | ||
|
|
64364e7911 | ||
|
|
6565cea039 | ||
|
|
3ebd8d6c07 | ||
|
|
e970185161 | ||
|
|
fa5653cdf7 | ||
|
|
9a7b000995 | ||
|
|
3a27242838 | ||
|
|
8cfb032051 | ||
|
|
06a9d4e2b2 | ||
|
|
ed46acee79 | ||
|
|
b54463d294 | ||
|
|
faee79dc95 | ||
|
|
965cd76e33 | ||
|
|
e5e8cbf34c | ||
|
|
3412a52594 | ||
|
|
e01f66b026 | ||
|
|
53abdde242 | ||
|
|
94c088300f | ||
|
|
3741a6f5e0 | ||
|
|
059336258f | ||
|
|
2c23b8414c | ||
|
|
271cc52c80 | ||
|
|
20356c0746 | ||
|
|
e44458609f | ||
|
|
69d86a7696 | ||
|
|
56db1a9292 | ||
|
|
cf50e5eeee | ||
|
|
c9c07968d2 | ||
|
|
97d0757176 | ||
|
|
0f51b677a9 | ||
|
|
56ca94c3a9 | ||
|
|
28d169f859 | ||
|
|
92f71d99ee | ||
|
|
0764c02b1d | ||
|
|
081c7569fe | ||
|
|
20f6532ee8 | ||
|
|
b9e8910478 | ||
|
|
ded8391e3c | ||
|
|
e9dd2c396a | ||
|
|
0d86de0cb5 | ||
|
|
bad1149504 | ||
|
|
fda7aaa7ca | ||
|
|
85c616fa34 | ||
|
|
549f4e9794 | ||
|
|
ef8ededd2f | ||
|
|
1948ffe106 | ||
|
|
c70f4404c4 | ||
|
|
b157ae928c | ||
|
|
7a0871992d | ||
|
|
b38e2e14f4 | ||
|
|
7c0e70ec84 | ||
|
|
a89ae9d2bf | ||
|
|
ad1fcb3f07 | ||
|
|
87d74b910b | ||
|
|
7ad1c297a4 | ||
|
|
fbc629faa6 | ||
|
|
7baa6b3c09 | ||
|
|
53d482bade | ||
|
|
5aca04b51b | ||
|
|
ea8787c8ff | ||
|
|
cead2c4445 | ||
|
|
f76ac1808c | ||
|
|
f01210861b | ||
|
|
f757f23ef0 | ||
|
|
872a6ef209 | ||
|
|
4267e5ffc4 | ||
|
|
a69c5ff9ef | ||
|
|
3ebd8d7d1b | ||
|
|
1fd80d54a4 | ||
|
|
991f63e455 | ||
|
|
6a1efd3527 | ||
|
|
0eadc0dd9e | ||
|
|
481423d678 | ||
|
|
89ede0aef3 | ||
|
|
359bdee9c6 | ||
|
|
0e6fba3763 | ||
|
|
652502d7a6 | ||
|
|
91d981a49e | ||
|
|
24f61d21b2 | ||
|
|
eb9a4177c5 | ||
|
|
3c43351a5b | ||
|
|
b1359b6dff | ||
|
|
bddccf6d2f | ||
|
|
21ffaab2a2 | ||
|
|
1e969f938f | ||
|
|
9c6c86ee4f | ||
|
|
6b53a48b48 | ||
|
|
c813fa3fc0 | ||
|
|
a08e61184a | ||
|
|
a0d62a5f41 | ||
|
|
616c0f11e1 | ||
|
|
e1626a4e49 | ||
|
|
6ab891a319 | ||
|
|
492de41316 | ||
|
|
c064efc866 | ||
|
|
1a0885bfb1 | ||
|
|
e8b202d0a5 | ||
|
|
c6fc82f756 | ||
|
|
9a77e951d2 | ||
|
|
8bd4207a27 | ||
|
|
0bb601aaf7 | ||
|
|
2da25a0043 | ||
|
|
51d0931898 | ||
|
|
357b68d1ba | ||
|
|
d9ddb6c32e | ||
|
|
ad02a99a83 | ||
|
|
b707dafc7b | ||
|
|
02906c8f5d | ||
|
|
8538e508f1 | ||
|
|
8c333ffd14 | ||
|
|
72ace5fdff | ||
|
|
9b7583fc84 | ||
|
|
989eee338e | ||
|
|
acc3d7b91b | ||
|
|
49de868658 | ||
|
|
b1702c7d90 | ||
|
|
e49e19ea13 | ||
|
|
c9f91f391e | ||
|
|
4cb6b2b701 | ||
|
|
7d132ea148 | ||
|
|
1088accd91 | ||
|
|
8d237d8f8b | ||
|
|
0c86a3232d | ||
|
|
dbfb0359cb | ||
|
|
b4c2aa596b | ||
|
|
87e89b7995 | ||
|
|
9b089430e2 | ||
|
|
f2b0025958 |
14
SECURITY.md
Normal file
14
SECURITY.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Only the latest version of Invoke will receive security updates.
|
||||
We do not currently maintain multiple versions of the application with updates.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
To report a vulnerability, contact the Invoke team directly at security@invoke.ai
|
||||
|
||||
At this time, we do not maintain a formal bug bounty program.
|
||||
|
||||
You can also share identified security issues with our team on huntr.com
|
||||
@@ -2,29 +2,42 @@
|
||||
|
||||
## Builder stage
|
||||
|
||||
FROM library/ubuntu:23.04 AS builder
|
||||
FROM library/ubuntu:24.04 AS builder
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt-get install -y \
|
||||
git \
|
||||
python3-venv \
|
||||
python3-pip \
|
||||
build-essential
|
||||
build-essential \
|
||||
git
|
||||
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv/invokeai
|
||||
# Install `uv` for package management
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
|
||||
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV UV_COMPILE_BYTECODE=1
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ARG GPU_DRIVER=cuda
|
||||
ARG TARGETPLATFORM="linux/amd64"
|
||||
# unused but available
|
||||
ARG BUILDPLATFORM
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
|
||||
RUN mkdir -p ${VIRTUAL_ENV} && \
|
||||
mkdir -p ${INVOKEAI_SRC} && \
|
||||
chmod -R a+w /opt
|
||||
USER ubuntu
|
||||
|
||||
# Install python and create the venv
|
||||
RUN uv python install ${PYTHON_VERSION} && \
|
||||
uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
|
||||
|
||||
WORKDIR ${INVOKEAI_SRC}
|
||||
COPY invokeai ./invokeai
|
||||
COPY pyproject.toml ./
|
||||
|
||||
@@ -32,25 +45,18 @@ COPY pyproject.toml ./
|
||||
# the local working copy can be bind-mounted into the image
|
||||
# at path defined by ${INVOKEAI_SRC}
|
||||
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
|
||||
# x86_64/CUDA is default
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m venv ${VIRTUAL_ENV} &&\
|
||||
# x86_64/CUDA is the default
|
||||
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
|
||||
else \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
|
||||
fi &&\
|
||||
fi && \
|
||||
uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."
|
||||
|
||||
# xformers + triton fails to install on arm64
|
||||
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
|
||||
pip install $extra_index_url_arg -e ".[xformers]"; \
|
||||
else \
|
||||
pip install $extra_index_url_arg -e "."; \
|
||||
fi
|
||||
|
||||
# #### Build the Web UI ------------------------------------
|
||||
#### Build the Web UI ------------------------------------
|
||||
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
@@ -66,7 +72,7 @@ RUN npx vite build
|
||||
|
||||
#### Runtime stage ---------------------------------------
|
||||
|
||||
FROM library/ubuntu:23.04 AS runtime
|
||||
FROM library/ubuntu:24.04 AS runtime
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
@@ -83,17 +89,16 @@ RUN apt update && apt install -y --no-install-recommends \
|
||||
gosu \
|
||||
magic-wormhole \
|
||||
libglib2.0-0 \
|
||||
libgl1-mesa-glx \
|
||||
python3-venv \
|
||||
python3-pip \
|
||||
libgl1 \
|
||||
libglx-mesa0 \
|
||||
build-essential \
|
||||
libopencv-dev \
|
||||
libstdc++-10-dev &&\
|
||||
apt-get clean && apt-get autoclean
|
||||
|
||||
|
||||
ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv
|
||||
ENV PYTHON_VERSION=3.11
|
||||
ENV INVOKEAI_ROOT=/invokeai
|
||||
ENV INVOKEAI_HOST=0.0.0.0
|
||||
ENV INVOKEAI_PORT=9090
|
||||
@@ -101,6 +106,14 @@ ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
|
||||
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
|
||||
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
# Install `uv` for package management
|
||||
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
|
||||
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
|
||||
USER ubuntu
|
||||
RUN uv python install ${PYTHON_VERSION}
|
||||
USER root
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4
|
||||
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
|
||||
COPY --link --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
@@ -115,7 +128,7 @@ WORKDIR ${INVOKEAI_SRC}
|
||||
|
||||
# build patchmatch
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python3 -c "from patchmatch import patch_match"
|
||||
RUN python -c "from patchmatch import patch_match"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ set -e -o pipefail
|
||||
|
||||
USER_ID=${CONTAINER_UID:-1000}
|
||||
USER=ubuntu
|
||||
# if the user does not exist, create it. It is expected to be present on ubuntu >=24.x
|
||||
_=$(id ${USER} 2>&1) || useradd -u ${USER_ID} ${USER}
|
||||
# ensure the UID is correct
|
||||
usermod -u ${USER_ID} ${USER} 1>/dev/null
|
||||
|
||||
### Set the $PUBLIC_KEY env var to enable SSH access.
|
||||
@@ -36,6 +39,8 @@ fi
|
||||
mkdir -p "${INVOKEAI_ROOT}"
|
||||
chown --recursive ${USER} "${INVOKEAI_ROOT}" || true
|
||||
cd "${INVOKEAI_ROOT}"
|
||||
export HF_HOME=${HF_HOME:-$INVOKEAI_ROOT/.cache/huggingface}
|
||||
export MPLCONFIGDIR=${MPLCONFIGDIR:-$INVOKEAI_ROOT/.matplotlib}
|
||||
|
||||
# Run the CMD as the Container User (not root).
|
||||
exec gosu ${USER} "$@"
|
||||
|
||||
@@ -50,7 +50,7 @@ Applications are built on top of the invoke framework. They should construct `in
|
||||
|
||||
### Web UI
|
||||
|
||||
The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/frontend` and the backend code is found in `/ldm/invoke/app/api_app.py` and `/ldm/invoke/app/api/`. The code is further organized as such:
|
||||
The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.tiangolo.com/) and [Socket.IO](https://socket.io/). The frontend code is found in `/invokeai/frontend` and the backend code is found in `/invokeai/app/api_app.py` and `/invokeai/app/api/`. The code is further organized as such:
|
||||
|
||||
| Component | Description |
|
||||
| --- | --- |
|
||||
@@ -62,7 +62,7 @@ The Web UI is built on top of an HTTP API built with [FastAPI](https://fastapi.t
|
||||
|
||||
### CLI
|
||||
|
||||
The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/ldm/invoke/app/cli_app.py`.
|
||||
The CLI is built automatically from invocation metadata, and also supports invocation piping and auto-linking. Code is available in `/invokeai/frontend/cli`.
|
||||
|
||||
## Invoke
|
||||
|
||||
@@ -70,7 +70,7 @@ The Invoke framework provides the interface to the underlying AI systems and is
|
||||
|
||||
### Invoker
|
||||
|
||||
The invoker (`/ldm/invoke/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
|
||||
The invoker (`/invokeai/app/services/invoker.py`) is the primary interface through which applications interact with the framework. Its primary purpose is to create, manage, and invoke sessions. It also maintains two sets of services:
|
||||
- **invocation services**, which are used by invocations to interact with core functionality.
|
||||
- **invoker services**, which are used by the invoker to manage sessions and manage the invocation queue.
|
||||
|
||||
@@ -82,12 +82,12 @@ The session graph does not support looping. This is left as an application probl
|
||||
|
||||
### Invocations
|
||||
|
||||
Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/ldm/invoke/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
|
||||
Invocations represent individual units of execution, with inputs and outputs. All invocations are located in `/invokeai/app/invocations`, and are all automatically discovered and made available in the applications. These are the primary way to expose new functionality in Invoke.AI, and the [implementation guide](INVOCATIONS.md) explains how to add new invocations.
|
||||
|
||||
### Services
|
||||
|
||||
Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/ldm/invoke/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).
|
||||
Services provide invocations access AI Core functionality and other necessary functionality (e.g. image storage). These are available in `/invokeai/app/services`. As a general rule, new services should provide an interface as an abstract base class, and may provide a lightweight local implementation by default in their module. The goal for all services should be to enable the usage of different implementations (e.g. using cloud storage for image storage), but should not load any module dependencies unless that implementation has been used (i.e. don't import anything that won't be used, especially if it's expensive to import).
|
||||
|
||||
## AI Core
|
||||
|
||||
The AI Core is represented by the rest of the code base (i.e. the code outside of `/ldm/invoke/app/`).
|
||||
The AI Core is represented by the rest of the code base (i.e. the code outside of `/invokeai/app/`).
|
||||
|
||||
@@ -287,8 +287,8 @@ new Invocation ready to be used.
|
||||
|
||||
Once you've created a Node, the next step is to share it with the community! The
|
||||
best way to do this is to submit a Pull Request to add the Node to the
|
||||
[Community Nodes](nodes/communityNodes) list. If you're not sure how to do that,
|
||||
take a look a at our [contributing nodes overview](contributingNodes).
|
||||
[Community Nodes](../nodes/communityNodes.md) list. If you're not sure how to do that,
|
||||
take a look a at our [contributing nodes overview](../nodes/contributingNodes.md).
|
||||
|
||||
## Advanced
|
||||
|
||||
|
||||
@@ -9,20 +9,20 @@ model. These are the:
|
||||
configuration information. Among other things, the record service
|
||||
tracks the type of the model, its provenance, and where it can be
|
||||
found on disk.
|
||||
|
||||
|
||||
* _ModelInstallServiceBase_ A service for installing models to
|
||||
disk. It uses `DownloadQueueServiceBase` to download models and
|
||||
their metadata, and `ModelRecordServiceBase` to store that
|
||||
information. It is also responsible for managing the InvokeAI
|
||||
`models` directory and its contents.
|
||||
|
||||
|
||||
* _DownloadQueueServiceBase_
|
||||
A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
queue has special methods for downloading repo_id folders from
|
||||
Hugging Face, as well as discriminating among model versions in
|
||||
Civitai, but can be used for arbitrary content.
|
||||
|
||||
|
||||
* _ModelLoadServiceBase_
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
@@ -207,9 +207,9 @@ for use in the InvokeAI web server. Its signature is:
|
||||
|
||||
```
|
||||
def open(
|
||||
cls,
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
cls,
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
|
||||
```
|
||||
@@ -363,7 +363,7 @@ functionality:
|
||||
|
||||
* Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
|
||||
* Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
@@ -371,21 +371,21 @@ functionality:
|
||||
|
||||
* Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
|
||||
* Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
|
||||
* Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir`.
|
||||
|
||||
* Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16.
|
||||
|
||||
|
||||
* Saving tags and other metadata about the model into the invokeai database
|
||||
when fetching from a repo that provides that type of information,
|
||||
(currently only HuggingFace).
|
||||
|
||||
|
||||
### Initializing the installer
|
||||
|
||||
A default installer is created at InvokeAI api startup time and stored
|
||||
@@ -461,7 +461,7 @@ revision.
|
||||
`config` is an optional dict of values that will override the
|
||||
autoprobed values for model type, base, scheduler prediction type, and
|
||||
so forth. See [Model configuration and
|
||||
probing](#Model-configuration-and-probing) for details.
|
||||
probing](#model-configuration-and-probing) for details.
|
||||
|
||||
`access_token` is an optional access token for accessing resources
|
||||
that need authentication.
|
||||
@@ -494,7 +494,7 @@ source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', ac
|
||||
|
||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
install_job = installer.install_model(source)
|
||||
|
||||
|
||||
source2job = installer.wait_for_installs(timeout=120)
|
||||
for source in sources:
|
||||
job = source2job[source]
|
||||
@@ -504,7 +504,7 @@ for source in sources:
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
|
||||
```
|
||||
|
||||
As shown here, the `import_model()` method accepts a variety of
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# InvokeAI Backend Tests
|
||||
|
||||
We use `pytest` to run the backend python tests. (See [pyproject.toml](/pyproject.toml) for the default `pytest` options.)
|
||||
We use `pytest` to run the backend python tests. (See [pyproject.toml](https://github.com/invoke-ai/InvokeAI/blob/main/pyproject.toml) for the default `pytest` options.)
|
||||
|
||||
## Fast vs. Slow
|
||||
All tests are categorized as either 'fast' (no test annotation) or 'slow' (annotated with the `@pytest.mark.slow` decorator).
|
||||
@@ -33,7 +33,7 @@ pytest tests -m ""
|
||||
|
||||
## Test Organization
|
||||
|
||||
All backend tests are in the [`tests/`](/tests/) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.
|
||||
All backend tests are in the [`tests/`](https://github.com/invoke-ai/InvokeAI/tree/main/tests) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.
|
||||
|
||||
TODO: The above statement is aspirational. A re-organization of legacy tests is required to make it true.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
## **What do I need to know to help?**
|
||||
|
||||
If you are looking to help with a code contribution, InvokeAI uses several different technologies under the hood: Python (Pydantic, FastAPI, diffusers) and Typescript (React, Redux Toolkit, ChakraUI, Mantine, Konva). Familiarity with StableDiffusion and image generation concepts is helpful, but not essential.
|
||||
If you are looking to help with a code contribution, InvokeAI uses several different technologies under the hood: Python (Pydantic, FastAPI, diffusers) and Typescript (React, Redux Toolkit, ChakraUI, Mantine, Konva). Familiarity with StableDiffusion and image generation concepts is helpful, but not essential.
|
||||
|
||||
|
||||
## **Get Started**
|
||||
@@ -12,7 +12,7 @@ To get started, take a look at our [new contributors checklist](newContributorCh
|
||||
Once you're setup, for more information, you can review the documentation specific to your area of interest:
|
||||
|
||||
* #### [InvokeAI Architecure](../ARCHITECTURE.md)
|
||||
* #### [Frontend Documentation](https://github.com/invoke-ai/InvokeAI/tree/main/invokeai/frontend/web)
|
||||
* #### [Frontend Documentation](../frontend/index.md)
|
||||
* #### [Node Documentation](../INVOCATIONS.md)
|
||||
* #### [Local Development](../LOCAL_DEVELOPMENT.md)
|
||||
|
||||
@@ -20,15 +20,15 @@ Once you're setup, for more information, you can review the documentation specif
|
||||
|
||||
If you don't feel ready to make a code contribution yet, no problem! You can also help out in other ways, such as [documentation](documentation.md), [translation](translation.md) or helping support other users and triage issues as they're reported in GitHub.
|
||||
|
||||
There are two paths to making a development contribution:
|
||||
There are two paths to making a development contribution:
|
||||
|
||||
1. Choosing an open issue to address. Open issues can be found in the [Issues](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen) section of the InvokeAI repository. These are tagged by the issue type (bug, enhancement, etc.) along with the “good first issues” tag denoting if they are suitable for first time contributors.
|
||||
1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
|
||||
1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
|
||||
2. Opening a new issue or feature to add. **Please make sure you have searched through existing issues before creating new ones.**
|
||||
|
||||
*Regardless of what you choose, please post in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord before you start development in order to confirm that the issue or feature is aligned with the current direction of the project. We value our contributors time and effort and want to ensure that no one’s time is being misspent.*
|
||||
|
||||
## Best Practices:
|
||||
## Best Practices:
|
||||
* Keep your pull requests small. Smaller pull requests are more likely to be accepted and merged
|
||||
* Comments! Commenting your code helps reviewers easily understand your contribution
|
||||
* Use Python and Typescript’s typing systems, and consider using an editor with [LSP](https://microsoft.github.io/language-server-protocol/) support to streamline development
|
||||
@@ -38,7 +38,7 @@ There are two paths to making a development contribution:
|
||||
|
||||
If you need help, you can ask questions in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord.
|
||||
|
||||
For frontend related work, **@psychedelicious** is the best person to reach out to.
|
||||
For frontend related work, **@psychedelicious** is the best person to reach out to.
|
||||
|
||||
For backend related work, please reach out to **@blessedcoolant**, **@lstein**, **@StAlKeR7779** or **@psychedelicious**.
|
||||
|
||||
|
||||
@@ -22,15 +22,15 @@ Before starting these steps, ensure you have your local environment [configured
|
||||
2. Fork the [InvokeAI](https://github.com/invoke-ai/InvokeAI) repository to your GitHub profile. This means that you will have a copy of the repository under **your-GitHub-username/InvokeAI**.
|
||||
3. Clone the repository to your local machine using:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-GitHub-username/InvokeAI.git
|
||||
```
|
||||
```bash
|
||||
git clone https://github.com/your-GitHub-username/InvokeAI.git
|
||||
```
|
||||
|
||||
If you're unfamiliar with using Git through the commandline, [GitHub Desktop](https://desktop.github.com) is a easy-to-use alternative with a UI. You can do all the same steps listed here, but through the interface. 4. Create a new branch for your fix using:
|
||||
|
||||
```bash
|
||||
git checkout -b branch-name-here
|
||||
```
|
||||
```bash
|
||||
git checkout -b branch-name-here
|
||||
```
|
||||
|
||||
5. Make the appropriate changes for the issue you are trying to address or the feature that you want to add.
|
||||
6. Add the file contents of the changed files to the "snapshot" git uses to manage the state of the project, also known as the index:
|
||||
|
||||
@@ -27,9 +27,9 @@ If you just want to use Invoke, you should use the [installer][installer link].
|
||||
|
||||
5. Activate the venv (you'll need to do this every time you want to run the app):
|
||||
|
||||
```sh
|
||||
source .venv/bin/activate
|
||||
```
|
||||
```sh
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
6. Install the repo as an [editable install][editable install link]:
|
||||
|
||||
@@ -37,7 +37,7 @@ If you just want to use Invoke, you should use the [installer][installer link].
|
||||
pip install -e ".[dev,test,xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
Refer to the [manual installation][manual install link]] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
|
||||
Refer to the [manual installation][manual install link] instructions for more determining the correct install options. `xformers` is optional, but `dev` and `test` are not.
|
||||
|
||||
7. Install the frontend dev toolchain:
|
||||
|
||||
|
||||
@@ -34,11 +34,11 @@ Please reach out to @hipsterusername on [Discord](https://discord.gg/ZmtBAhwWhy)
|
||||
|
||||
## Contributors
|
||||
|
||||
This project is a combined effort of dedicated people from across the world. [Check out the list of all these amazing people](https://invoke-ai.github.io/InvokeAI/other/CONTRIBUTORS/). We thank them for their time, hard work and effort.
|
||||
This project is a combined effort of dedicated people from across the world. [Check out the list of all these amazing people](contributors.md). We thank them for their time, hard work and effort.
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](https://github.com/invoke-ai/InvokeAI/blob/main/CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
|
||||
The InvokeAI community is a welcoming place, and we want your help in maintaining that. Please review our [Code of Conduct](../CODE_OF_CONDUCT.md) to learn more - it's essential to maintaining a respectful and inclusive environment.
|
||||
|
||||
By making a contribution to this project, you certify that:
|
||||
|
||||
|
||||
@@ -99,7 +99,6 @@ their descriptions.
|
||||
| Scale Latents | Scales latents by a given factor. |
|
||||
| Segment Anything Processor | Applies segment anything processing to image |
|
||||
| Show Image | Displays a provided image, and passes it forward in the pipeline. |
|
||||
| Step Param Easing | Experimental per-step parameter easing for denoising steps |
|
||||
| String Primitive Collection | A collection of string primitive values |
|
||||
| String Primitive | A string primitive value |
|
||||
| Subtract Integers | Subtracts two numbers |
|
||||
|
||||
@@ -110,7 +110,7 @@ async def cancel_by_batch_ids(
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel_by_destination",
|
||||
operation_id="cancel_by_destination",
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
responses={200: {"model": CancelByDestinationResult}},
|
||||
)
|
||||
async def cancel_by_destination(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
|
||||
@@ -59,11 +59,32 @@ logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
|
||||
# the correct port when the server starts in the lifespan handler.
|
||||
port = app_config.port
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
||||
|
||||
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
|
||||
proto = "https" if app_config.ssl_certfile else "http"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
|
||||
|
||||
# Logging this way ignores the logger's log level and _always_ logs the message
|
||||
record = logger.makeRecord(
|
||||
name=logger.name,
|
||||
level=logging.INFO,
|
||||
fn="",
|
||||
lno=0,
|
||||
msg=msg,
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
logger.handle(record)
|
||||
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
@@ -206,6 +227,7 @@ def invoke_api() -> None:
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
global port
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
@@ -217,18 +239,17 @@ def invoke_api() -> None:
|
||||
host=app_config.host,
|
||||
port=port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level,
|
||||
log_level=app_config.log_level_network,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
for logname in ["uvicorn.access", "uvicorn"]:
|
||||
log = InvokeAILogger.get_logger(logname)
|
||||
log.handlers.clear()
|
||||
for ch in logger.handlers:
|
||||
log.addHandler(ch)
|
||||
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
|
||||
uvicorn_logger.handlers.clear()
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ custom_nodes_readme_path = str(custom_nodes_path / "README.md")
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/init.py", custom_nodes_init_path)
|
||||
shutil.copy(Path(__file__).parent / "custom_nodes/README.md", custom_nodes_readme_path)
|
||||
|
||||
# set the same permissions as the destination directory, in case our source is read-only,
|
||||
# so that the files are user-writable
|
||||
for p in custom_nodes_path.glob("**/*"):
|
||||
p.chmod(custom_nodes_path.stat().st_mode)
|
||||
|
||||
# Import custom nodes, see https://docs.python.org/3/library/importlib.html#importing-programmatically
|
||||
spec = spec_from_file_location("custom_nodes", custom_nodes_init_path)
|
||||
if spec is None or spec.loader is None:
|
||||
|
||||
@@ -63,6 +63,7 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
|
||||
- `Deprecated`: The invocation is deprecated and may be removed in a future version.
|
||||
- `Internal`: The invocation is not intended for use by end-users. It may be changed or removed at any time, but is exposed for users to play with.
|
||||
- `Special`: The invocation is a special case and does not fit into any of the other classifications.
|
||||
"""
|
||||
|
||||
Stable = "stable"
|
||||
@@ -70,6 +71,7 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Prototype = "prototype"
|
||||
Deprecated = "deprecated"
|
||||
Internal = "internal"
|
||||
Special = "special"
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
|
||||
@@ -1,98 +1,120 @@
|
||||
from typing import Any, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, LatentsField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def slerp(
|
||||
t: Union[float, np.ndarray],
|
||||
v0: Union[torch.Tensor, np.ndarray],
|
||||
v1: Union[torch.Tensor, np.ndarray],
|
||||
device: torch.device,
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
|
||||
|
||||
@invocation(
|
||||
"lblend",
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
tags=["latents", "blend", "mask"],
|
||||
category="latents",
|
||||
version="1.0.3",
|
||||
version="1.1.0",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
"""Blend two latents using a given alpha. If a mask is provided, the second latents will be masked before blending.
|
||||
Latents must have same size. Masking functionality added by @dwringer."""
|
||||
|
||||
latents_a: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents_b: LatentsField = InputField(
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
latents_a: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
latents_b: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
mask: Optional[ImageField] = InputField(default=None, description="Mask for blending in latents B")
|
||||
alpha: float = InputField(ge=0, default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
return mask_tensor
|
||||
|
||||
def replace_tensor_from_masked_tensor(
|
||||
self, tensor: torch.Tensor, other_tensor: torch.Tensor, mask_tensor: torch.Tensor
|
||||
):
|
||||
output = tensor.clone()
|
||||
mask_tensor = mask_tensor.expand(output.shape)
|
||||
if output.dtype != torch.float16:
|
||||
output = torch.add(output, mask_tensor * torch.sub(other_tensor, tensor))
|
||||
else:
|
||||
output = torch.add(output, mask_tensor.half() * torch.sub(other_tensor, tensor))
|
||||
return output
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||
if self.mask is None:
|
||||
mask_tensor = torch.zeros(latents_a.shape[-2:])
|
||||
else:
|
||||
mask_tensor = self.prep_mask_tensor(context.images.get_pil(self.mask.image_name))
|
||||
mask_tensor = tv_resize(mask_tensor, latents_a.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
|
||||
latents_b = self.replace_tensor_from_masked_tensor(latents_b, latents_a, mask_tensor)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
raise ValueError("Latents to blend must be the same size.")
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
def slerp(
|
||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||
v0 (np.ndarray): Starting vector
|
||||
v1 (np.ndarray): Final vector
|
||||
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||
colineal. Not recommended to alter this.
|
||||
Returns:
|
||||
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||
"""
|
||||
inputs_are_torch = False
|
||||
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v0 = v0.detach().cpu().numpy()
|
||||
if not isinstance(v1, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
v1 = v1.detach().cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
||||
return v2_torch
|
||||
else:
|
||||
assert isinstance(v2, np.ndarray)
|
||||
return v2
|
||||
|
||||
# blend
|
||||
bl = slerp(self.alpha, latents_a, latents_b)
|
||||
assert isinstance(bl, torch.Tensor)
|
||||
blended_latents: torch.Tensor = bl # for type checking convenience
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b, device)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
@@ -19,9 +19,9 @@ from invokeai.app.invocations.model import CLIPField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
@@ -66,10 +66,10 @@ class CompelInvocation(BaseInvocation):
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@@ -82,7 +82,7 @@ class CompelInvocation(BaseInvocation):
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LayerPatcher.apply_model_patches(
|
||||
model=text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_te_",
|
||||
@@ -162,11 +162,11 @@ class SDXLPromptInvocationBase:
|
||||
c_pooled = None
|
||||
return c, c_pooled
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
assert isinstance(lora_model, ModelPatchRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@@ -179,7 +179,7 @@ class SDXLPromptInvocationBase:
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
||||
tokenizer_info as tokenizer,
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LayerPatcher.apply_model_patches(
|
||||
text_encoder,
|
||||
patches=_lora_loader(),
|
||||
prefix=lora_prefix,
|
||||
|
||||
1563
invokeai/app/invocations/composition-nodes.py
Normal file
1563
invokeai/app/invocations/composition-nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,6 @@ from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
|
||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
@@ -29,11 +28,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == torch.float32,
|
||||
description=FieldDescriptions.fp32,
|
||||
ui_order=4,
|
||||
)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=4)
|
||||
|
||||
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
|
||||
@@ -7,7 +7,6 @@ from PIL import Image, ImageFilter
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
@@ -76,11 +75,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
ui_order=7,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
||||
fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == torch.float32,
|
||||
description=FieldDescriptions.fp32,
|
||||
ui_order=9,
|
||||
)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32, ui_order=9)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
|
||||
@@ -37,10 +37,10 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
@@ -987,10 +987,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@@ -1003,7 +1003,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LayerPatcher.apply_model_patches(
|
||||
model=unet,
|
||||
patches=_lora_loader(),
|
||||
prefix="lora_unet_",
|
||||
|
||||
@@ -56,6 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
CLIPLEmbedModel = "CLIPLEmbedModelField"
|
||||
CLIPGEmbedModel = "CLIPGEmbedModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
ControlLoRAModel = "ControlLoRAModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@@ -143,6 +144,7 @@ class FieldDescriptions:
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
control_lora_model = "Control LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
flux_model = "Flux model (Transformer) to load"
|
||||
sd3_model = "SD3 model (MMDiTX) to load"
|
||||
@@ -250,6 +252,11 @@ class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class SD3ConditioningField(BaseModel):
|
||||
|
||||
49
invokeai/app/invocations/flux_control_lora_loader.py
Normal file
49
invokeai/app/invocations/flux_control_lora_loader.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("flux_control_lora_loader_output")
|
||||
class FluxControlLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""Flux Control LoRA Loader Output"""
|
||||
|
||||
control_lora: ControlLoRAField = OutputField(
|
||||
title="Flux Control LoRA", description="Control LoRAs to apply on model loading", default=None
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_control_lora_loader",
|
||||
title="Flux Control LoRA",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
"""LoRA model and Image to use with FLUX transformer generation."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
|
||||
)
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
weight: float = InputField(description="The weight of the LoRA.", default=1.0)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
|
||||
if not context.models.exists(self.lora.key):
|
||||
raise ValueError(f"Unknown lora: {self.lora.key}!")
|
||||
|
||||
return FluxControlLoRALoaderOutput(
|
||||
control_lora=ControlLoRAField(
|
||||
lora=self.lora,
|
||||
img=self.image,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -1,10 +1,12 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
from typing import Callable, Iterator, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
@@ -21,8 +23,9 @@ from invokeai.app.invocations.fields import (
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
@@ -30,6 +33,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
|
||||
@@ -42,10 +46,11 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -56,7 +61,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.2.1",
|
||||
version="3.2.2",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -87,10 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
control_lora: Optional[ControlLoRAField] = InputField(
|
||||
description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control LoRA", default=None
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_text_conditioning: FluxConditioningField | None = InputField(
|
||||
negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
@@ -139,36 +147,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _load_text_conditioning(
|
||||
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load the conditioning data.
|
||||
cond_data = context.conditioning.load(conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
return t5_embeddings, clip_embeddings
|
||||
|
||||
def _run_diffusion(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
):
|
||||
inference_dtype = torch.bfloat16
|
||||
|
||||
# Load the conditioning data.
|
||||
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
|
||||
context, self.positive_text_conditioning.conditioning_name, inference_dtype
|
||||
)
|
||||
neg_t5_embeddings: torch.Tensor | None = None
|
||||
neg_clip_embeddings: torch.Tensor | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
|
||||
context, self.negative_text_conditioning.conditioning_name, inference_dtype
|
||||
)
|
||||
|
||||
# Load the input latents, if provided.
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
@@ -183,15 +167,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
b, _c, latent_h, latent_w = noise.shape
|
||||
packed_h = latent_h // 2
|
||||
packed_w = latent_w // 2
|
||||
|
||||
# Load the conditioning data.
|
||||
pos_text_conditionings = self._load_text_conditioning(
|
||||
context=context,
|
||||
cond_field=self.positive_text_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
neg_text_conditionings: list[FluxTextConditioning] | None = None
|
||||
if self.negative_text_conditioning is not None:
|
||||
neg_text_conditionings = self._load_text_conditioning(
|
||||
context=context,
|
||||
cond_field=self.negative_text_conditioning,
|
||||
packed_height=packed_h,
|
||||
packed_width=packed_w,
|
||||
dtype=inference_dtype,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
)
|
||||
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
|
||||
pos_text_conditionings, img_seq_len=packed_h * packed_w
|
||||
)
|
||||
neg_regional_prompting_extension = (
|
||||
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
|
||||
if neg_text_conditionings
|
||||
else None
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=image_seq_len,
|
||||
image_seq_len=packed_h * packed_w,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
@@ -226,30 +240,26 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if len(timesteps) <= 1:
|
||||
return x
|
||||
|
||||
if is_schnell and self.control_lora:
|
||||
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
|
||||
|
||||
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
|
||||
img_cond = self._prep_structural_control_img_cond(context)
|
||||
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
|
||||
b, _c, latent_h, latent_w = x.shape
|
||||
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
|
||||
pos_txt_ids = torch.zeros(
|
||||
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
neg_txt_ids: torch.Tensor | None = None
|
||||
if neg_t5_embeddings is not None:
|
||||
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
|
||||
neg_txt_ids = torch.zeros(
|
||||
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
img_cond = pack(img_cond) if img_cond is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
|
||||
# packed_w correctly.
|
||||
assert packed_h * packed_w == x.shape[1]
|
||||
|
||||
# Prepare inpaint extension.
|
||||
inpaint_extension: InpaintExtension | None = None
|
||||
@@ -299,7 +309,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if config.format in [ModelFormat.Checkpoint]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LayerPatcher.apply_model_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
@@ -314,7 +324,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference,
|
||||
# than directly patching the weights, but is agnostic to the quantization format.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_sidecar_patches(
|
||||
LayerPatcher.apply_model_sidecar_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
@@ -338,12 +348,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
model=transformer,
|
||||
img=x,
|
||||
img_ids=img_ids,
|
||||
txt=pos_t5_embeddings,
|
||||
txt_ids=pos_txt_ids,
|
||||
vec=pos_clip_embeddings,
|
||||
neg_txt=neg_t5_embeddings,
|
||||
neg_txt_ids=neg_txt_ids,
|
||||
neg_vec=neg_clip_embeddings,
|
||||
pos_regional_prompting_extension=pos_regional_prompting_extension,
|
||||
neg_regional_prompting_extension=neg_regional_prompting_extension,
|
||||
timesteps=timesteps,
|
||||
step_callback=self._build_step_callback(context),
|
||||
guidance=self.guidance,
|
||||
@@ -352,11 +358,49 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
controlnet_extensions=controlnet_extensions,
|
||||
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
img_cond=img_cond,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
return x
|
||||
|
||||
def _load_text_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
cond_field: FluxConditioningField | list[FluxConditioningField],
|
||||
packed_height: int,
|
||||
packed_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> list[FluxTextConditioning]:
|
||||
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
|
||||
# Normalize to a list of FluxConditioningFields.
|
||||
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
|
||||
|
||||
text_conditionings: list[FluxTextConditioning] = []
|
||||
for cond_field in cond_list:
|
||||
# Load the text embeddings.
|
||||
cond_data = context.conditioning.load(cond_field.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
flux_conditioning = cond_data.conditionings[0]
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
|
||||
t5_embeddings = flux_conditioning.t5_embeds
|
||||
clip_embeddings = flux_conditioning.clip_embeds
|
||||
|
||||
# Load the mask, if provided.
|
||||
mask: Optional[torch.Tensor] = None
|
||||
if cond_field.mask is not None:
|
||||
mask = context.tensors.load(cond_field.mask.tensor_name)
|
||||
mask = mask.to(device=device)
|
||||
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, packed_height, packed_width, dtype, device
|
||||
)
|
||||
|
||||
text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
|
||||
|
||||
return text_conditionings
|
||||
|
||||
@classmethod
|
||||
def prep_cfg_scale(
|
||||
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
|
||||
@@ -545,6 +589,29 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return controlnet_extensions
|
||||
|
||||
def _prep_structural_control_img_cond(self, context: InvocationContext) -> torch.Tensor | None:
|
||||
if self.control_lora is None:
|
||||
return None
|
||||
|
||||
if not self.controlnet_vae:
|
||||
raise ValueError("controlnet_vae must be set when using a FLUX Control LoRA.")
|
||||
|
||||
# Load the conditioning image and resize it to the target image size.
|
||||
cond_img = context.images.get_pil(self.control_lora.img.image_name)
|
||||
cond_img = cond_img.convert("RGB")
|
||||
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
|
||||
cond_img = np.array(cond_img)
|
||||
|
||||
# Normalize the conditioning image to the range [-1, 1].
|
||||
# This normalization is based on the original implementations here:
|
||||
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L34
|
||||
# https://github.com/black-forest-labs/flux/blob/805da8571a0b49b6d4043950bd266a65328c243b/src/flux/modules/image_embedders.py#L60
|
||||
img_cond = torch.from_numpy(cond_img).float() / 127.5 - 1.0
|
||||
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.controlnet_vae.vae)
|
||||
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
|
||||
|
||||
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
|
||||
if self.ip_adapter is None:
|
||||
return []
|
||||
@@ -651,10 +718,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.transformer.loras:
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
|
||||
if self.control_lora:
|
||||
# Note: Since FLUX structural control LoRAs modify the shape of some weights, it is important that they are
|
||||
# applied last.
|
||||
loras.append(self.control_lora)
|
||||
for lora in loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Tuple
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
|
||||
@@ -22,7 +29,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
title="FLUX Text Encoding",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
@@ -41,7 +48,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
@@ -54,7 +64,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
return FluxConditioningOutput(
|
||||
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
|
||||
)
|
||||
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
@@ -99,7 +111,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LayerPatcher.apply_model_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
@@ -118,9 +130,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return pooled_prompt_embeds
|
||||
|
||||
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
59
invokeai/app/invocations/image_panels.py
Normal file
59
invokeai/app/invocations/image_panels.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("image_panel_coordinate_output")
|
||||
class ImagePanelCoordinateOutput(BaseInvocationOutput):
|
||||
x_left: int = OutputField(description="The left x-coordinate of the panel.")
|
||||
y_top: int = OutputField(description="The top y-coordinate of the panel.")
|
||||
width: int = OutputField(description="The width of the panel.")
|
||||
height: int = OutputField(description="The height of the panel.")
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_panel_layout",
|
||||
title="Image Panel Layout",
|
||||
tags=["image", "panel", "layout"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class ImagePanelLayoutInvocation(BaseInvocation):
|
||||
"""Get the coordinates of a single panel in a grid. (If the full image shape cannot be divided evenly into panels,
|
||||
then the grid may not cover the entire image.)
|
||||
"""
|
||||
|
||||
width: int = InputField(description="The width of the entire grid.")
|
||||
height: int = InputField(description="The height of the entire grid.")
|
||||
num_cols: int = InputField(ge=1, default=1, description="The number of columns in the grid.")
|
||||
num_rows: int = InputField(ge=1, default=1, description="The number of rows in the grid.")
|
||||
panel_col_idx: int = InputField(ge=0, default=0, description="The column index of the panel to be processed.")
|
||||
panel_row_idx: int = InputField(ge=0, default=0, description="The row index of the panel to be processed.")
|
||||
|
||||
@field_validator("panel_col_idx")
|
||||
def validate_panel_col_idx(cls, v: int, info: ValidationInfo) -> int:
|
||||
if v < 0 or v >= info.data["num_cols"]:
|
||||
raise ValueError(f"panel_col_idx must be between 0 and {info.data['num_cols'] - 1}")
|
||||
return v
|
||||
|
||||
@field_validator("panel_row_idx")
|
||||
def validate_panel_row_idx(cls, v: int, info: ValidationInfo) -> int:
|
||||
if v < 0 or v >= info.data["num_rows"]:
|
||||
raise ValueError(f"panel_row_idx must be between 0 and {info.data['num_rows'] - 1}")
|
||||
return v
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImagePanelCoordinateOutput:
|
||||
x_left = self.panel_col_idx * (self.width // self.num_cols)
|
||||
y_top = self.panel_row_idx * (self.height // self.num_rows)
|
||||
width = self.width // self.num_cols
|
||||
height = self.height // self.num_rows
|
||||
return ImagePanelCoordinateOutput(x_left=x_left, y_top=y_top, width=width, height=height)
|
||||
@@ -13,7 +13,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@@ -49,7 +49,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
|
||||
# offer a way to directly set None values.
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(
|
||||
|
||||
@@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
@@ -51,7 +51,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
|
||||
# offer a way to directly set None values.
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -65,11 +65,6 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class T5EncoderField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
@@ -80,6 +75,15 @@ class VAEField(BaseModel):
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
class ControlLoRAField(LoRAField):
|
||||
img: ImageField = Field(description="Image to use in structural conditioning")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a UNet field."""
|
||||
|
||||
@@ -1,43 +1,4 @@
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from easing_functions import (
|
||||
BackEaseIn,
|
||||
BackEaseInOut,
|
||||
BackEaseOut,
|
||||
BounceEaseIn,
|
||||
BounceEaseInOut,
|
||||
BounceEaseOut,
|
||||
CircularEaseIn,
|
||||
CircularEaseInOut,
|
||||
CircularEaseOut,
|
||||
CubicEaseIn,
|
||||
CubicEaseInOut,
|
||||
CubicEaseOut,
|
||||
ElasticEaseIn,
|
||||
ElasticEaseInOut,
|
||||
ElasticEaseOut,
|
||||
ExponentialEaseIn,
|
||||
ExponentialEaseInOut,
|
||||
ExponentialEaseOut,
|
||||
LinearInOut,
|
||||
QuadEaseIn,
|
||||
QuadEaseInOut,
|
||||
QuadEaseOut,
|
||||
QuarticEaseIn,
|
||||
QuarticEaseInOut,
|
||||
QuarticEaseOut,
|
||||
QuinticEaseIn,
|
||||
QuinticEaseInOut,
|
||||
QuinticEaseOut,
|
||||
SineEaseIn,
|
||||
SineEaseInOut,
|
||||
SineEaseOut,
|
||||
)
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField
|
||||
@@ -65,191 +26,3 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@invocation(
|
||||
"step_param_easing",
|
||||
title="Step Param Easing",
|
||||
tags=["step", "easing"],
|
||||
category="step",
|
||||
version="1.0.2",
|
||||
)
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||
end_value: float = InputField(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = InputField(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = InputField(default=None, description="value after easing end")
|
||||
mirror: bool = InputField(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
context.logger.debug("start_step: " + str(start_step))
|
||||
context.logger.debug("end_step: " + str(end_step))
|
||||
context.logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
context.logger.debug("num_presteps: " + str(num_presteps))
|
||||
context.logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||
context.logger.debug("prelist: " + str(prelist))
|
||||
context.logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
context.logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = []
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||
if log_diagnostics:
|
||||
context.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||
easing_function = easing_class(
|
||||
start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1,
|
||||
)
|
||||
base_easing_vals = []
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
context.logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(
|
||||
start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1,
|
||||
)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
@@ -4,7 +4,13 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoundingBoxField,
|
||||
@@ -533,3 +539,23 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "internal"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
@@ -16,10 +16,10 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import SD3ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
|
||||
|
||||
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
|
||||
@@ -150,7 +150,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
|
||||
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
|
||||
exit_stack.enter_context(
|
||||
LoRAPatcher.apply_lora_patches(
|
||||
LayerPatcher.apply_model_patches(
|
||||
model=clip_text_encoder,
|
||||
patches=self._clip_lora_iterator(context, clip_model),
|
||||
prefix=FLUX_LORA_CLIP_PREFIX,
|
||||
@@ -193,9 +193,9 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
|
||||
def _clip_lora_iterator(
|
||||
self, context: InvocationContext, clip_model: CLIPField
|
||||
) -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
) -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in clip_model.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
@@ -22,8 +22,8 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
@@ -194,10 +194,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
assert isinstance(lora_info.model, ModelPatchRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
@@ -207,7 +207,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
unet_info as unet,
|
||||
LoRAPatcher.apply_lora_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import filecmp
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
@@ -96,6 +97,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
log_sql: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
|
||||
log_level_network: Log level for network-related messages. 'info' and 'debug' are very verbose.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
use_memory_db: Use in-memory database. Useful for development.
|
||||
dev_reload: Automatically reload when Python sources are changed. Does not reload node definitions.
|
||||
profile_graphs: Enable graph profiling using `cProfile`.
|
||||
@@ -162,6 +164,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
log_format: LOG_FORMAT = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.')
|
||||
log_level: LOG_LEVEL = Field(default="info", description="Emit logging messages at this level or higher.")
|
||||
log_sql: bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.")
|
||||
log_level_network: LOG_LEVEL = Field(default='warning', description="Log level for network-related messages. 'info' and 'debug' are very verbose.")
|
||||
|
||||
# Development
|
||||
use_memory_db: bool = Field(default=False, description="Use in-memory database. Useful for development.")
|
||||
@@ -525,9 +528,35 @@ def get_config() -> InvokeAIAppConfig:
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
# Copy all legacy configs only if needed
|
||||
# We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
dest_path = config.legacy_conf_path
|
||||
|
||||
# Create destination (we don't need to check for existence)
|
||||
dest_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Compare directories recursively
|
||||
comparison = filecmp.dircmp(configs_src, dest_path)
|
||||
need_copy = any(
|
||||
[
|
||||
comparison.left_only, # Files exist only in source
|
||||
comparison.diff_files, # Files that differ
|
||||
comparison.common_funny, # Files that couldn't be compared
|
||||
]
|
||||
)
|
||||
|
||||
if need_copy:
|
||||
# Get permissions from destination directory
|
||||
dest_mode = dest_path.stat().st_mode
|
||||
|
||||
# Copy directory tree
|
||||
shutil.copytree(configs_src, dest_path, dirs_exist_ok=True)
|
||||
|
||||
# Set permissions on copied files to match destination directory
|
||||
dest_path.chmod(dest_mode)
|
||||
for p in dest_path.glob("**/*"):
|
||||
p.chmod(dest_mode)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
|
||||
@@ -438,9 +438,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
source_stripped = source.strip('"')
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
if Path(source_stripped).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source_stripped))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
|
||||
@@ -86,7 +86,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
@@ -378,6 +378,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
# Make sure the cancel event is for the currently processing queue item
|
||||
if self._queue_item and self._queue_item.item_id != event[1].item_id:
|
||||
return
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
@@ -436,7 +439,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
self._invoker.services.logger.info(
|
||||
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
|
||||
)
|
||||
cancel_event.clear()
|
||||
|
||||
# Run the graph
|
||||
|
||||
@@ -16,6 +16,7 @@ from pydantic import (
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowWithoutID,
|
||||
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
|
||||
|
||||
# region Batch
|
||||
|
||||
BatchDataType = Union[
|
||||
StrictStr,
|
||||
float,
|
||||
int,
|
||||
]
|
||||
BatchDataType = Union[StrictStr, float, int, ImageField]
|
||||
|
||||
|
||||
class NodeFieldValue(BaseModel):
|
||||
|
||||
@@ -35,7 +35,7 @@ class Migration11Callback:
|
||||
|
||||
def _remove_convert_cache(self) -> None:
|
||||
"""Rename models/.cache to models/.convert_cache."""
|
||||
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
|
||||
self._logger.info("Removing models/.cache directory. Converted models will now be cached in .convert_cache.")
|
||||
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
|
||||
shutil.rmtree(legacy_convert_path, ignore_errors=True)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.math import attention
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
|
||||
class CustomDoubleStreamBlockProcessor:
|
||||
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:
|
||||
|
||||
@staticmethod
|
||||
def _double_stream_block_forward(
|
||||
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
|
||||
block: DoubleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
|
||||
values.
|
||||
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
|
||||
- IP-Adapter support
|
||||
"""
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
|
||||
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
|
||||
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
|
||||
block, img, txt, vec, pe, attn_mask=attn_mask
|
||||
)
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
for ip_adapter_extension in ip_adapter_extensions:
|
||||
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class CustomSingleStreamBlockProcessor:
|
||||
"""A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
|
||||
etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _single_stream_block_forward(
|
||||
block: SingleStreamBlock,
|
||||
x: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""This function is a direct copy of SingleStreamBlock.forward()."""
|
||||
mod, _ = block.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
|
||||
q, k = block.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
@staticmethod
|
||||
def custom_single_block_forward(
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
block_index: int,
|
||||
block: SingleStreamBlock,
|
||||
img: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
pe: torch.Tensor,
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> torch.Tensor:
|
||||
"""A custom implementation of SingleStreamBlock.forward() with additional features:
|
||||
- Masking
|
||||
"""
|
||||
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
|
||||
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
|
||||
|
||||
@@ -7,6 +7,7 @@ from tqdm import tqdm
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
@@ -18,14 +19,8 @@ def denoise(
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
# positive text conditioning
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# negative text conditioning
|
||||
neg_txt: torch.Tensor | None,
|
||||
neg_txt_ids: torch.Tensor | None,
|
||||
neg_vec: torch.Tensor | None,
|
||||
pos_regional_prompting_extension: RegionalPromptingExtension,
|
||||
neg_regional_prompting_extension: RegionalPromptingExtension | None,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
@@ -35,6 +30,8 @@ def denoise(
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
# extra img tokens
|
||||
img_cond: torch.Tensor | None,
|
||||
):
|
||||
# step 0 is the initial state
|
||||
total_steps = len(timesteps) - 1
|
||||
@@ -61,9 +58,9 @@ def denoise(
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
@@ -74,13 +71,13 @@ def denoise(
|
||||
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
|
||||
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||
|
||||
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
|
||||
pred = model(
|
||||
img=img,
|
||||
img=pred_img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
@@ -88,6 +85,7 @@ def denoise(
|
||||
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
ip_adapter_extensions=pos_ip_adapter_extensions,
|
||||
regional_prompting_extension=pos_regional_prompting_extension,
|
||||
)
|
||||
|
||||
step_cfg_scale = cfg_scale[step_index]
|
||||
@@ -97,15 +95,15 @@ def denoise(
|
||||
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
|
||||
# on systems with sufficient VRAM.
|
||||
|
||||
if neg_txt is None or neg_txt_ids is None or neg_vec is None:
|
||||
if neg_regional_prompting_extension is None:
|
||||
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
|
||||
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
|
||||
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
|
||||
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timestep_index=step_index,
|
||||
@@ -113,6 +111,7 @@ def denoise(
|
||||
controlnet_double_block_residuals=None,
|
||||
controlnet_single_block_residuals=None,
|
||||
ip_adapter_extensions=neg_ip_adapter_extensions,
|
||||
regional_prompting_extension=neg_regional_prompting_extension,
|
||||
)
|
||||
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
|
||||
|
||||
|
||||
276
invokeai/backend/flux/extensions/regional_prompting_extension.py
Normal file
276
invokeai/backend/flux/extensions/regional_prompting_extension.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.mask import to_standard_float_mask
|
||||
|
||||
|
||||
class RegionalPromptingExtension:
|
||||
"""A class for managing regional prompting with FLUX.
|
||||
|
||||
This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
restricted_attn_mask: torch.Tensor | None = None,
|
||||
):
|
||||
self.regional_text_conditioning = regional_text_conditioning
|
||||
self.restricted_attn_mask = restricted_attn_mask
|
||||
|
||||
def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
|
||||
order = [self.restricted_attn_mask, None]
|
||||
return order[block_index % len(order)]
|
||||
|
||||
@classmethod
|
||||
def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
|
||||
"""Create a RegionalPromptingExtension from a list of text conditionings.
|
||||
|
||||
Args:
|
||||
text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
|
||||
img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
|
||||
"""
|
||||
regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
|
||||
attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
|
||||
regional_text_conditioning, img_seq_len
|
||||
)
|
||||
return cls(
|
||||
regional_text_conditioning=regional_text_conditioning,
|
||||
restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
|
||||
)
|
||||
|
||||
# Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
|
||||
#
|
||||
# @classmethod
|
||||
# def _prepare_unrestricted_attn_mask(
|
||||
# cls,
|
||||
# regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
# img_seq_len: int,
|
||||
# ) -> torch.Tensor:
|
||||
# """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
|
||||
# - img self-attention is not masked.
|
||||
# - img regions attend to both txt within their own region and to global prompts.
|
||||
# """
|
||||
# device = TorchDevice.choose_torch_device()
|
||||
|
||||
# # Infer txt_seq_len from the t5_embeddings tensor.
|
||||
# txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# # Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# # There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# # 1. txt attends to itself
|
||||
# # 2. txt attends to corresponding regional img
|
||||
# # 3. regional img attends to corresponding txt
|
||||
# # 4. regional img attends to itself
|
||||
|
||||
# # Initialize empty attention mask.
|
||||
# regional_attention_mask = torch.zeros(
|
||||
# (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
|
||||
# )
|
||||
|
||||
# for image_mask, t5_embedding_range in zip(
|
||||
# regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
|
||||
# ):
|
||||
# # 1. txt attends to itself
|
||||
# regional_attention_mask[
|
||||
# t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
# ] = 1.0
|
||||
|
||||
# # 2. txt attends to corresponding regional img
|
||||
# # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
# fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
|
||||
# regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
|
||||
|
||||
# # 3. regional img attends to corresponding txt
|
||||
# # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
# fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
|
||||
# regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
|
||||
|
||||
# # 4. regional img attends to itself
|
||||
# # Allow unrestricted img self attention.
|
||||
# regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
|
||||
|
||||
# # Convert attention mask to boolean.
|
||||
# regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
# return regional_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _prepare_restricted_attn_mask(
|
||||
cls,
|
||||
regional_text_conditioning: FluxRegionalTextConditioning,
|
||||
img_seq_len: int,
|
||||
) -> torch.Tensor | None:
|
||||
"""Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
|
||||
- img self-attention is only allowed within regions.
|
||||
- img regions only attend to txt within their own region, not to global prompts.
|
||||
"""
|
||||
# Identify background region. I.e. the region that is not covered by any region masks.
|
||||
background_region_mask: None | torch.Tensor = None
|
||||
for image_mask in regional_text_conditioning.image_masks:
|
||||
if image_mask is not None:
|
||||
if background_region_mask is None:
|
||||
background_region_mask = torch.ones_like(image_mask)
|
||||
background_region_mask *= 1 - image_mask
|
||||
|
||||
if background_region_mask is None:
|
||||
# There are no region masks, short-circuit and return None.
|
||||
# TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would
|
||||
# is a rare use case and would make the logic here significantly more complicated.
|
||||
return None
|
||||
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
# Infer txt_seq_len from the t5_embeddings tensor.
|
||||
txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
# In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
|
||||
# Concatenation happens in the following order: [txt_seq, img_seq].
|
||||
# There are 4 portions of the attention mask to consider as we prepare it:
|
||||
# 1. txt attends to itself
|
||||
# 2. txt attends to corresponding regional img
|
||||
# 3. regional img attends to corresponding txt
|
||||
# 4. regional img attends to itself
|
||||
|
||||
# Initialize empty attention mask.
|
||||
regional_attention_mask = torch.zeros(
|
||||
(txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
|
||||
)
|
||||
|
||||
for image_mask, t5_embedding_range in zip(
|
||||
regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
|
||||
):
|
||||
# 1. txt attends to itself
|
||||
regional_attention_mask[
|
||||
t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
|
||||
] = 1.0
|
||||
|
||||
if image_mask is not None:
|
||||
# 2. txt attends to corresponding regional img
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
image_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. regional img attends to corresponding txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
image_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# 4. regional img attends to itself
|
||||
image_mask = image_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
|
||||
else:
|
||||
# We don't allow attention between non-background image regions and global prompts. This helps to ensure
|
||||
# that regions focus on their local prompts. We do, however, allow attention between background regions
|
||||
# and global prompts. If we didn't do this, then the background regions would not attend to any txt
|
||||
# embeddings, which we found experimentally to cause artifacts.
|
||||
|
||||
# 2. global txt attends to background region
|
||||
# Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
|
||||
background_region_mask.view(1, img_seq_len)
|
||||
)
|
||||
|
||||
# 3. background region attends to global txt
|
||||
# Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
|
||||
regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
|
||||
background_region_mask.view(img_seq_len, 1)
|
||||
)
|
||||
|
||||
# Allow background regions to attend to themselves.
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
|
||||
regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
|
||||
|
||||
# Convert attention mask to boolean.
|
||||
regional_attention_mask = regional_attention_mask > 0.5
|
||||
|
||||
return regional_attention_mask
|
||||
|
||||
@classmethod
|
||||
def _concat_regional_text_conditioning(
|
||||
cls,
|
||||
text_conditionings: list[FluxTextConditioning],
|
||||
) -> FluxRegionalTextConditioning:
|
||||
"""Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
|
||||
concat_t5_embeddings: list[torch.Tensor] = []
|
||||
concat_t5_embedding_ranges: list[Range] = []
|
||||
image_masks: list[torch.Tensor | None] = []
|
||||
|
||||
# Choose global CLIP embedding.
|
||||
# Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
|
||||
# the first prompt's CLIP embedding.
|
||||
global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
|
||||
for text_conditioning in text_conditionings:
|
||||
if text_conditioning.mask is None:
|
||||
global_clip_embedding = text_conditioning.clip_embeddings
|
||||
break
|
||||
|
||||
cur_t5_embedding_len = 0
|
||||
for text_conditioning in text_conditionings:
|
||||
concat_t5_embeddings.append(text_conditioning.t5_embeddings)
|
||||
|
||||
concat_t5_embedding_ranges.append(
|
||||
Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
|
||||
)
|
||||
|
||||
image_masks.append(text_conditioning.mask)
|
||||
|
||||
cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
|
||||
|
||||
t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
|
||||
|
||||
# Initialize the txt_ids tensor.
|
||||
pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
|
||||
t5_txt_ids = torch.zeros(
|
||||
pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
return FluxRegionalTextConditioning(
|
||||
t5_embeddings=t5_embeddings,
|
||||
clip_embeddings=global_clip_embedding,
|
||||
t5_txt_ids=t5_txt_ids,
|
||||
image_masks=image_masks,
|
||||
t5_embedding_ranges=concat_t5_embedding_ranges,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_regional_prompt_mask(
|
||||
mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||
|
||||
packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
|
||||
"""
|
||||
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
|
||||
|
||||
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
|
||||
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
resized_mask = tf(mask)
|
||||
|
||||
# Flatten the height and width dimensions into a single image_seq_len dimension.
|
||||
return resized_mask.flatten(start_dim=2)
|
||||
@@ -41,10 +41,12 @@ def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Te
|
||||
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
|
||||
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
|
||||
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
|
||||
clip_extra_context_tokens = state_dict["ip_adapter_proj_model.proj.weight"].shape[0] // context_dim
|
||||
|
||||
return XlabsIpAdapterParams(
|
||||
num_double_blocks=num_double_blocks,
|
||||
context_dim=context_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
clip_embeddings_dim=clip_embeddings_dim,
|
||||
clip_extra_context_tokens=clip_extra_context_tokens,
|
||||
)
|
||||
|
||||
@@ -31,13 +31,16 @@ class XlabsIpAdapterParams:
|
||||
hidden_dim: int
|
||||
|
||||
clip_embeddings_dim: int
|
||||
clip_extra_context_tokens: int
|
||||
|
||||
|
||||
class XlabsIpAdapterFlux(torch.nn.Module):
|
||||
def __init__(self, params: XlabsIpAdapterParams):
|
||||
super().__init__()
|
||||
self.image_proj = ImageProjModel(
|
||||
cross_attention_dim=params.context_dim, clip_embeddings_dim=params.clip_embeddings_dim
|
||||
cross_attention_dim=params.context_dim,
|
||||
clip_embeddings_dim=params.clip_embeddings_dim,
|
||||
clip_extra_context_tokens=params.clip_extra_context_tokens,
|
||||
)
|
||||
self.ip_adapter_double_blocks = IPAdapterDoubleBlocks(
|
||||
num_double_blocks=params.num_double_blocks, context_dim=params.context_dim, hidden_dim=params.hidden_dim
|
||||
|
||||
@@ -5,10 +5,10 @@ from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
@@ -24,12 +24,12 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
return out.to(dtype=pos.dtype, device=pos.device)
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.view(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.view(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
return xq_out.view(*xq.shape).type_as(xq), xk_out.view(*xk.shape).type_as(xk)
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
|
||||
from invokeai.backend.flux.custom_block_processor import (
|
||||
CustomDoubleStreamBlockProcessor,
|
||||
CustomSingleStreamBlockProcessor,
|
||||
)
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
@@ -31,6 +36,7 @@ class FluxParams:
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
out_channels: Optional[int] = None
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@@ -43,7 +49,7 @@ class Flux(nn.Module):
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
self.out_channels = params.out_channels or self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
@@ -95,6 +101,7 @@ class Flux(nn.Module):
|
||||
controlnet_double_block_residuals: list[Tensor] | None,
|
||||
controlnet_single_block_residuals: list[Tensor] | None,
|
||||
ip_adapter_extensions: list[XLabsIPAdapterExtension],
|
||||
regional_prompting_extension: RegionalPromptingExtension,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -117,7 +124,6 @@ class Flux(nn.Module):
|
||||
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
assert isinstance(block, DoubleStreamBlock)
|
||||
|
||||
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
@@ -128,6 +134,7 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
ip_adapter_extensions=ip_adapter_extensions,
|
||||
regional_prompting_extension=regional_prompting_extension,
|
||||
)
|
||||
|
||||
if controlnet_double_block_residuals is not None:
|
||||
@@ -140,7 +147,17 @@ class Flux(nn.Module):
|
||||
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
|
||||
|
||||
for block_index, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
assert isinstance(block, SingleStreamBlock)
|
||||
img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
|
||||
timestep_index=timestep_index,
|
||||
total_num_timesteps=total_num_timesteps,
|
||||
block_index=block_index,
|
||||
block=block,
|
||||
img=img,
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
regional_prompting_extension=regional_prompting_extension,
|
||||
)
|
||||
|
||||
if controlnet_single_block_residuals is not None:
|
||||
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
|
||||
|
||||
@@ -66,10 +66,7 @@ class RMSNorm(torch.nn.Module):
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
return torch.nn.functional.rms_norm(x, self.scale.shape, self.scale, eps=1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
|
||||
36
invokeai/backend/flux/text_conditioning.py
Normal file
36
invokeai/backend/flux/text_conditioning.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxTextConditioning:
|
||||
t5_embeddings: torch.Tensor
|
||||
clip_embeddings: torch.Tensor
|
||||
# If mask is None, the prompt is a global prompt.
|
||||
mask: torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxRegionalTextConditioning:
|
||||
# Concatenated text embeddings.
|
||||
# Shape: (1, concatenated_txt_seq_len, 4096)
|
||||
t5_embeddings: torch.Tensor
|
||||
# Shape: (1, concatenated_txt_seq_len, 3)
|
||||
t5_txt_ids: torch.Tensor
|
||||
|
||||
# Global CLIP embeddings.
|
||||
# Shape: (1, 768)
|
||||
clip_embeddings: torch.Tensor
|
||||
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a
|
||||
# global prompt.
|
||||
# image_masks[i] is the mask for the ith prompt.
|
||||
# image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
|
||||
image_masks: list[torch.Tensor | None]
|
||||
|
||||
# List of ranges that represent the embedding ranges for each mask.
|
||||
# t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
|
||||
t5_embedding_ranges: list[Range]
|
||||
BIN
invokeai/backend/image_util/assets/CIELab_to_UPLab.icc
Normal file
BIN
invokeai/backend/image_util/assets/CIELab_to_UPLab.icc
Normal file
Binary file not shown.
1020
invokeai/backend/image_util/composition.py
Normal file
1020
invokeai/backend/image_util/composition.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,11 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.full_layer import FullLayer
|
||||
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer]
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
|
||||
|
||||
class ConcatenatedLoRALinearSidecarLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
concatenated_lora_layer: ConcatenatedLoRALayer,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._concatenated_lora_layer = concatenated_lora_layer
|
||||
self._weight = weight
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
x_chunks: list[torch.Tensor] = []
|
||||
for lora_layer in self._concatenated_lora_layer.lora_layers:
|
||||
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
|
||||
x_chunk *= self._weight * lora_layer.scale()
|
||||
x_chunks.append(x_chunk)
|
||||
|
||||
# TODO(ryand): Generalize to support concat_axis != 0.
|
||||
assert self._concatenated_lora_layer.concat_axis == 0
|
||||
x = torch.cat(x_chunks, dim=-1)
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._concatenated_lora_layer.to(device=device, dtype=dtype)
|
||||
return self
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
|
||||
|
||||
class LoRALinearSidecarLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
lora_layer: LoRALayer,
|
||||
weight: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._lora_layer = lora_layer
|
||||
self._weight = weight
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.down)
|
||||
if self._lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.mid)
|
||||
x = torch.nn.functional.linear(x, self._lora_layer.up, bias=self._lora_layer.bias)
|
||||
x *= self._weight * self._lora_layer.scale()
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._lora_layer.to(device=device, dtype=dtype)
|
||||
return self
|
||||
@@ -1,24 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class LoRASidecarModule(torch.nn.Module):
|
||||
"""A LoRA sidecar module that wraps an original module and adds LoRA layers to it."""
|
||||
|
||||
def __init__(self, orig_module: torch.nn.Module, lora_layers: list[torch.nn.Module]):
|
||||
super().__init__()
|
||||
self.orig_module = orig_module
|
||||
self._lora_layers = lora_layers
|
||||
|
||||
def add_lora_layer(self, lora_layer: torch.nn.Module):
|
||||
self._lora_layers.append(lora_layer)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
x = self.orig_module(input)
|
||||
for lora_layer in self._lora_layers:
|
||||
x += lora_layer(input)
|
||||
return x
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self._orig_module.to(device=device, dtype=dtype)
|
||||
for lora_layer in self._lora_layers:
|
||||
lora_layer.to(device=device, dtype=dtype)
|
||||
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
|
||||
Main = "main"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlLoRa = "control_lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@@ -273,6 +274,36 @@ class LoRALyCORISConfig(LoRAConfigBase):
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class ControlAdapterConfigBase(BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class LoRADiffusersConfig(LoRAConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
@@ -304,12 +335,6 @@ class VAEDiffusersConfig(ModelConfigBase):
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlAdapterConfigBase(BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
@@ -535,6 +560,8 @@ AnyModelConfig = Annotated[
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
|
||||
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
|
||||
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
|
||||
|
||||
@@ -9,14 +9,6 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@@ -28,10 +20,25 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
|
||||
is_state_dict_likely_flux_control,
|
||||
lora_model_from_flux_control_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_model_from_flux_diffusers_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
lora_model_from_flux_kohya_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.Diffusers)
|
||||
class LoRALoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
@@ -75,7 +82,10 @@ class LoRALoader(ModelLoader):
|
||||
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_flux_control(state_dict=state_dict):
|
||||
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
|
||||
@@ -15,9 +15,9 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
@@ -43,7 +43,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
(
|
||||
TextualInversionModelRaw,
|
||||
IPAdapter,
|
||||
LoRAModelRaw,
|
||||
ModelPatchRaw,
|
||||
SpandrelImageToImageModel,
|
||||
GroundingDinoPipeline,
|
||||
SegmentAnythingPipeline,
|
||||
|
||||
@@ -15,10 +15,6 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
@@ -43,6 +39,13 @@ from invokeai.backend.model_manager.util.model_util import (
|
||||
lora_token_vector_length,
|
||||
read_checkpoint_meta,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
)
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
@@ -199,8 +202,8 @@ class ModelProbe(object):
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
|
||||
if not fields["default_settings"]:
|
||||
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
|
||||
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
|
||||
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
|
||||
fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
|
||||
elif fields["type"] is ModelType.Main:
|
||||
fields["default_settings"] = get_default_settings_main(fields["base"])
|
||||
|
||||
@@ -258,6 +261,9 @@ class ModelProbe(object):
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt):
|
||||
return ModelType.ControlLoRa
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if key.startswith(
|
||||
(
|
||||
@@ -469,7 +475,7 @@ class ModelProbe(object):
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
@@ -485,6 +491,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
|
||||
"lineart anime": "lineart_anime_image_processor",
|
||||
"lineart_anime": "lineart_anime_image_processor",
|
||||
"lineart": "lineart_image_processor",
|
||||
"soft": "hed_image_processor",
|
||||
"softedge": "hed_image_processor",
|
||||
"hed": "hed_image_processor",
|
||||
"shuffle": "content_shuffle_image_processor",
|
||||
@@ -496,7 +503,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
|
||||
}
|
||||
|
||||
|
||||
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||
def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
|
||||
model_name_lower = model_name.lower()
|
||||
if k in model_name_lower:
|
||||
@@ -623,8 +630,10 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return ModelFormat.LyCORIS
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if is_state_dict_likely_in_flux_kohya_format(self.checkpoint) or is_state_dict_likely_in_flux_diffusers_format(
|
||||
self.checkpoint
|
||||
if (
|
||||
is_state_dict_likely_in_flux_kohya_format(self.checkpoint)
|
||||
or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint)
|
||||
or is_state_dict_likely_flux_control(self.checkpoint)
|
||||
):
|
||||
return BaseModelType.Flux
|
||||
|
||||
@@ -1033,6 +1042,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
@@ -1045,6 +1055,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
|
||||
@@ -298,13 +298,12 @@ ip_adapter_sdxl = StarterModel(
|
||||
previous_names=["IP Adapter SDXL"],
|
||||
)
|
||||
ip_adapter_flux = StarterModel(
|
||||
name="Standard Reference (XLabs FLUX IP-Adapter)",
|
||||
name="Standard Reference (XLabs FLUX IP-Adapter v2)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
|
||||
source="https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/resolve/main/ip_adapter.safetensors",
|
||||
description="References images with a more generalized/looser degree of precision.",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[clip_vit_l_image_encoder],
|
||||
previous_names=["XLabs FLUX IP-Adapter"],
|
||||
)
|
||||
# endregion
|
||||
# region ControlNet
|
||||
@@ -489,6 +488,22 @@ union_cnet_flux = StarterModel(
|
||||
type=ModelType.ControlNet,
|
||||
)
|
||||
# endregion
|
||||
# region Control LoRA
|
||||
flux_canny_control_lora = StarterModel(
|
||||
name="Hard Edge Detection (Canny)",
|
||||
base=BaseModelType.Flux,
|
||||
source="black-forest-labs/FLUX.1-Canny-dev-lora::flux1-canny-dev-lora.safetensors",
|
||||
description="Uses detected edges in the image to control composition.",
|
||||
type=ModelType.ControlLoRa,
|
||||
)
|
||||
flux_depth_control_lora = StarterModel(
|
||||
name="Depth Map",
|
||||
base=BaseModelType.Flux,
|
||||
source="black-forest-labs/FLUX.1-Depth-dev-lora::flux1-depth-dev-lora.safetensors",
|
||||
description="Uses depth information in the image to control the depth in the generation.",
|
||||
type=ModelType.ControlLoRa,
|
||||
)
|
||||
# endregion
|
||||
# region T2I Adapter
|
||||
t2i_canny_sd1 = StarterModel(
|
||||
name="Hard Edge Detection (canny)",
|
||||
@@ -631,6 +646,8 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
tile_sdxl,
|
||||
union_cnet_sdxl,
|
||||
union_cnet_flux,
|
||||
flux_canny_control_lora,
|
||||
flux_depth_control_lora,
|
||||
t2i_canny_sd1,
|
||||
t2i_sketch_sd1,
|
||||
t2i_depth_sd1,
|
||||
@@ -689,6 +706,8 @@ flux_bundle: list[StarterModel] = [
|
||||
clip_l_encoder,
|
||||
union_cnet_flux,
|
||||
ip_adapter_flux,
|
||||
flux_canny_control_lora,
|
||||
flux_depth_control_lora,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
|
||||
|
||||
@@ -44,7 +44,7 @@ def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
|
||||
return checkpoint
|
||||
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
path_str = path.as_posix() if isinstance(path, Path) else path
|
||||
@@ -52,16 +52,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str
|
||||
except Exception:
|
||||
# TODO: create issue for support "meta"?
|
||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||
elif str(path).endswith(".gguf"):
|
||||
# The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function
|
||||
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
|
||||
else:
|
||||
if scan:
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||
if str(path).endswith(".gguf"):
|
||||
# The GGUF reader used here uses numpy memmap, so these tensors are not loaded into memory during this function
|
||||
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
|
||||
else:
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,14 @@ from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Iterator, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
|
||||
|
||||
@@ -176,180 +173,3 @@ class ModelPatcher:
|
||||
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
|
||||
if did_apply_freeu:
|
||||
unet.disable_freeu()
|
||||
|
||||
|
||||
class ONNXModelPatcher:
|
||||
# based on
|
||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: IAIOnnxRuntimeModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
) -> None:
|
||||
from invokeai.backend.models.base import IAIOnnxRuntimeModel
|
||||
|
||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_weights = {}
|
||||
|
||||
try:
|
||||
blended_loras: Dict[str, torch.Tensor] = {}
|
||||
|
||||
for lora, lora_weight in loras:
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
layer.to(dtype=torch.float32)
|
||||
layer_key = layer_key.replace(prefix, "")
|
||||
# TODO: rewrite to pass original tensor weight(required by ia3)
|
||||
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
|
||||
if layer_key in blended_loras:
|
||||
blended_loras[layer_key] += layer_weight
|
||||
else:
|
||||
blended_loras[layer_key] = layer_weight
|
||||
|
||||
node_names = {}
|
||||
for node in model.nodes.values():
|
||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
||||
|
||||
for layer_key, lora_weight in blended_loras.items():
|
||||
conv_key = layer_key + "_Conv"
|
||||
gemm_key = layer_key + "_Gemm"
|
||||
matmul_key = layer_key + "_MatMul"
|
||||
|
||||
if conv_key in node_names or gemm_key in node_names:
|
||||
if conv_key in node_names:
|
||||
conv_node = model.nodes[node_names[conv_key]]
|
||||
else:
|
||||
conv_node = model.nodes[node_names[gemm_key]]
|
||||
|
||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
||||
orig_weight = model.tensors[weight_name]
|
||||
|
||||
if orig_weight.shape[-2:] == (1, 1):
|
||||
if lora_weight.shape[-2:] == (1, 1):
|
||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
|
||||
else:
|
||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
|
||||
|
||||
new_weight = np.expand_dims(new_weight, (2, 3))
|
||||
else:
|
||||
if orig_weight.shape != lora_weight.shape:
|
||||
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
|
||||
else:
|
||||
new_weight = orig_weight + lora_weight
|
||||
|
||||
orig_weights[weight_name] = orig_weight
|
||||
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
|
||||
|
||||
elif matmul_key in node_names:
|
||||
weight_node = model.nodes[node_names[matmul_key]]
|
||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
||||
|
||||
orig_weight = model.tensors[matmul_name]
|
||||
new_weight = orig_weight + lora_weight.transpose()
|
||||
|
||||
orig_weights[matmul_name] = orig_weight
|
||||
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
|
||||
|
||||
else:
|
||||
# warn? err?
|
||||
pass
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# restore original weights
|
||||
for name, orig_weight in orig_weights.items():
|
||||
model.tensors[name] = orig_weight
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
cls,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: IAIOnnxRuntimeModel,
|
||||
ti_list: List[Tuple[str, Any]],
|
||||
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
|
||||
from invokeai.backend.models.base import IAIOnnxRuntimeModel
|
||||
|
||||
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
|
||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
||||
|
||||
orig_embeddings = None
|
||||
|
||||
try:
|
||||
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
|
||||
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
|
||||
# exiting this `apply_ti(...)` context manager.
|
||||
#
|
||||
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
|
||||
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
|
||||
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
|
||||
def _get_trigger(ti_name: str, index: int) -> str:
|
||||
trigger = ti_name
|
||||
if index > 0:
|
||||
trigger += f"-!pad-{i}"
|
||||
return f"<{trigger}>"
|
||||
|
||||
# modify text_encoder
|
||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||
|
||||
# modify tokenizer
|
||||
new_tokens_added = 0
|
||||
for ti_name, ti in ti_list:
|
||||
if ti.embedding_2 is not None:
|
||||
ti_embedding = (
|
||||
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
|
||||
)
|
||||
else:
|
||||
ti_embedding = ti.embedding
|
||||
|
||||
for i in range(ti_embedding.shape[0]):
|
||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||
|
||||
embeddings = np.concatenate(
|
||||
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
||||
axis=0,
|
||||
)
|
||||
|
||||
for ti_name, _ in ti_list:
|
||||
ti_tokens = []
|
||||
for i in range(ti_embedding.shape[0]):
|
||||
embedding = ti_embedding[i].detach().numpy()
|
||||
trigger = _get_trigger(ti_name, i)
|
||||
|
||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||
if token_id == ti_tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
||||
|
||||
if embeddings[token_id].shape != embedding.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
|
||||
f" {embedding.shape[0]}, but the current model has token dimension"
|
||||
f" {embeddings[token_id].shape[0]}."
|
||||
)
|
||||
|
||||
embeddings[token_id] = embedding
|
||||
ti_tokens.append(token_id)
|
||||
|
||||
if len(ti_tokens) > 1:
|
||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
||||
|
||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
|
||||
orig_embeddings.dtype
|
||||
)
|
||||
|
||||
yield ti_tokenizer, ti_manager
|
||||
|
||||
finally:
|
||||
# restore
|
||||
if orig_embeddings is not None:
|
||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
|
||||
|
||||
22
invokeai/backend/patches/layers/base_layer_patch.py
Normal file
22
invokeai/backend/patches/layers/base_layer_patch.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseLayerPatch(ABC):
|
||||
@abstractmethod
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
"""Get the parameter residual updates that should be applied to the original parameters. Parameters omitted
|
||||
from the returned dict are not updated.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
"""Move all internal tensors to the specified device and dtype."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def calc_size(self) -> int:
|
||||
"""Calculate the total size of all internal tensors in bytes."""
|
||||
...
|
||||
@@ -2,8 +2,8 @@ from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class ConcatenatedLoRALayer(LoRALayerBase):
|
||||
@@ -20,7 +20,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
|
||||
self.lora_layers = lora_layers
|
||||
self.concat_axis = concat_axis
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
19
invokeai/backend/patches/layers/flux_control_lora_layer.py
Normal file
19
invokeai/backend/patches/layers/flux_control_lora_layer.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
|
||||
|
||||
class FluxControlLoRALayer(LoRALayer):
|
||||
"""A special case of LoRALayer for use with FLUX Control LoRAs that pads the target parameter with zeros if the
|
||||
shapes don't match.
|
||||
"""
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
"""This overrides the base class behavior to skip the reshaping step."""
|
||||
scale = self.scale()
|
||||
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias * (weight * scale)
|
||||
|
||||
return params
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class FullLayer(LoRALayerBase):
|
||||
cls.warn_on_unhandled_keys(values=values, handled_keys={"diff", "diff_b"})
|
||||
return layer
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
@@ -16,7 +16,7 @@ class IA3Layer(LoRALayerBase):
|
||||
self.weight = weight
|
||||
self.on_input = on_input
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -2,7 +2,7 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class LoHALayer(LoRALayerBase):
|
||||
self.t2 = t2
|
||||
assert (self.t1 is None) == (self.t2 is None)
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
return self.w1_b.shape[0]
|
||||
|
||||
@classmethod
|
||||
@@ -2,7 +2,7 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
assert (self.w2 is None) != (self.w2_a is None)
|
||||
assert (self.w2_a is None) == (self.w2_b is None)
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
if self.w1_b is not None:
|
||||
return self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
return layer
|
||||
|
||||
def rank(self) -> int:
|
||||
def _rank(self) -> int:
|
||||
return self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1,12 +1,13 @@
|
||||
from typing import Dict, Optional, Set
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensors_size
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
class LoRALayerBase(BaseLayerPatch):
|
||||
"""Base class for all LoRA-like patching layers."""
|
||||
|
||||
# Note: It is tempting to make this a torch.nn.Module sub-class and make all tensors 'torch.nn.Parameter's. Then we
|
||||
@@ -23,6 +24,7 @@ class LoRALayerBase:
|
||||
def _parse_bias(
|
||||
cls, bias_indices: torch.Tensor | None, bias_values: torch.Tensor | None, bias_size: torch.Tensor | None
|
||||
) -> torch.Tensor | None:
|
||||
"""Helper function to parse a bias tensor from a state dict in LyCORIS format."""
|
||||
assert (bias_indices is None) == (bias_values is None) == (bias_size is None)
|
||||
|
||||
bias = None
|
||||
@@ -37,11 +39,14 @@ class LoRALayerBase:
|
||||
) -> float | None:
|
||||
return alpha.item() if alpha is not None else None
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
"""Return the rank of the LoRA-like layer. Or None if the layer does not have a rank. This value is used to
|
||||
calculate the scale.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def scale(self) -> float:
|
||||
rank = self.rank()
|
||||
rank = self._rank()
|
||||
if self._alpha is None or rank is None:
|
||||
return 1.0
|
||||
return self._alpha / rank
|
||||
@@ -52,15 +57,23 @@ class LoRALayerBase:
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
scale = self.scale()
|
||||
params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
params["bias"] = bias * (weight * scale)
|
||||
|
||||
# Reshape all params to match the original module's shape.
|
||||
for param_name, param_weight in params.items():
|
||||
orig_param = orig_module.get_parameter(param_name)
|
||||
if param_weight.shape != orig_param.shape:
|
||||
params[param_name] = param_weight.reshape(orig_param.shape)
|
||||
|
||||
return params
|
||||
|
||||
@classmethod
|
||||
def warn_on_unhandled_keys(cls, values: Dict[str, torch.Tensor], handled_keys: Set[str]):
|
||||
def warn_on_unhandled_keys(cls, values: dict[str, torch.Tensor], handled_keys: set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
unknown_keys = set(values.keys()) - handled_keys
|
||||
if unknown_keys:
|
||||
@@ -2,7 +2,7 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class NormLayer(LoRALayerBase):
|
||||
cls.warn_on_unhandled_keys(values, {"w_norm", "b_norm"})
|
||||
return layer
|
||||
|
||||
def rank(self) -> int | None:
|
||||
def _rank(self) -> int | None:
|
||||
return None
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
27
invokeai/backend/patches/layers/set_parameter_layer.py
Normal file
27
invokeai/backend/patches/layers/set_parameter_layer.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
class SetParameterLayer(BaseLayerPatch):
|
||||
"""A layer that sets a single parameter to a new target value.
|
||||
(The diff between the target value and current value is calculated internally.)
|
||||
"""
|
||||
|
||||
def __init__(self, param_name: str, weight: torch.Tensor):
|
||||
super().__init__()
|
||||
self.weight = weight
|
||||
self.param_name = param_name
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
|
||||
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
|
||||
# Control LoRA implementation.
|
||||
diff = self.weight - orig_module.get_parameter(self.param_name)
|
||||
return {self.param_name: diff}
|
||||
|
||||
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
return calc_tensor_size(self.weight)
|
||||
@@ -2,16 +2,16 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.full_layer import FullLayer
|
||||
from invokeai.backend.lora.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.lora.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.layers.norm_layer import NormLayer
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.full_layer import FullLayer
|
||||
from invokeai.backend.patches.layers.ia3_layer import IA3Layer
|
||||
from invokeai.backend.patches.layers.loha_layer import LoHALayer
|
||||
from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.layers.norm_layer import NormLayer
|
||||
|
||||
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
|
||||
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> BaseLayerPatch:
|
||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
|
||||
# Example keys:
|
||||
# guidance_in.in_layer.lora_B.bias
|
||||
# single_blocks.0.linear1.lora_A.weight
|
||||
# double_blocks.0.img_attn.norm.key_norm.scale
|
||||
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)"
|
||||
|
||||
|
||||
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
|
||||
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
|
||||
|
||||
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
|
||||
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
|
||||
"""
|
||||
|
||||
all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
|
||||
|
||||
# Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
|
||||
lora_a_weight = state_dict.get("img_in.lora_A.weight", None)
|
||||
lora_b_bias = state_dict.get("img_in.lora_B.bias", None)
|
||||
lora_b_weight = state_dict.get("img_in.lora_B.weight", None)
|
||||
|
||||
return (
|
||||
all_keys_match
|
||||
and lora_a_weight is not None
|
||||
and lora_b_bias is not None
|
||||
and lora_b_weight is not None
|
||||
and lora_a_weight.shape[1] == 128
|
||||
and lora_b_weight.shape[0] == 3072
|
||||
and lora_b_bias.shape[0] == 3072
|
||||
)
|
||||
|
||||
|
||||
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
key_props = key.split(".")
|
||||
layer_prop_size = -2 if any(prop in key for prop in ["lora_B", "lora_A"]) else -1
|
||||
layer_name = ".".join(key_props[:layer_prop_size])
|
||||
param_name = ".".join(key_props[layer_prop_size:])
|
||||
if layer_name not in grouped_state_dict:
|
||||
grouped_state_dict[layer_name] = {}
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
for layer_key, layer_state_dict in grouped_state_dict.items():
|
||||
prefixed_key = f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"
|
||||
if layer_key == "img_in":
|
||||
# img_in is a special case because it changes the shape of the original weight.
|
||||
layers[prefixed_key] = FluxControlLoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"],
|
||||
)
|
||||
elif all(k in layer_state_dict for k in ["lora_A.weight", "lora_B.bias", "lora_B.weight"]):
|
||||
layers[prefixed_key] = LoRALayer(
|
||||
layer_state_dict["lora_B.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_A.weight"],
|
||||
None,
|
||||
layer_state_dict["lora_B.bias"],
|
||||
)
|
||||
elif "scale" in layer_state_dict:
|
||||
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
|
||||
else:
|
||||
raise ValueError(f"{layer_key} not expected")
|
||||
|
||||
return ModelPatchRaw(layers=layers)
|
||||
@@ -2,11 +2,11 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
|
||||
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
|
||||
@@ -30,7 +30,9 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
|
||||
return all_keys_in_peft_format and all_expected_keys_present
|
||||
|
||||
|
||||
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
|
||||
def lora_model_from_flux_diffusers_state_dict(
|
||||
state_dict: Dict[str, torch.Tensor], alpha: float | None
|
||||
) -> ModelPatchRaw:
|
||||
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
|
||||
|
||||
This function is based on:
|
||||
@@ -49,7 +51,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
mlp_ratio = 4.0
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
|
||||
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
|
||||
if src_key in grouped_state_dict:
|
||||
@@ -215,7 +217,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
|
||||
|
||||
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
|
||||
|
||||
return LoRAModelRaw(layers=layers_with_prefix)
|
||||
return ModelPatchRaw(layers=layers_with_prefix)
|
||||
|
||||
|
||||
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
@@ -3,10 +3,13 @@ from typing import Any, Dict, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
|
||||
FLUX_LORA_CLIP_PREFIX,
|
||||
FLUX_LORA_TRANSFORMER_PREFIX,
|
||||
)
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format.
|
||||
# Example keys:
|
||||
@@ -36,7 +39,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
|
||||
)
|
||||
|
||||
|
||||
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
for key, value in state_dict.items():
|
||||
@@ -61,14 +64,14 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
|
||||
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
|
||||
|
||||
# Create LoRA layers.
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
for layer_key, layer_state_dict in transformer_grouped_sd.items():
|
||||
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
|
||||
for layer_key, layer_state_dict in clip_grouped_sd.items():
|
||||
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
|
||||
|
||||
# Create and return the LoRAModelRaw.
|
||||
return LoRAModelRaw(layers=layers)
|
||||
return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -2,19 +2,19 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
|
||||
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
|
||||
def lora_model_from_sd_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_state(state_dict)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
for layer_key, values in grouped_state_dict.items():
|
||||
layers[layer_key] = any_lora_layer_from_state_dict(values)
|
||||
|
||||
return LoRAModelRaw(layers=layers)
|
||||
return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
@@ -3,20 +3,17 @@ from typing import Mapping, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
def __init__(self, layers: Mapping[str, AnyLoRALayer]):
|
||||
class ModelPatchRaw(RawModel):
|
||||
def __init__(self, layers: Mapping[str, BaseLayerPatch]):
|
||||
self.layers = layers
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
for _key, layer in self.layers.items():
|
||||
for layer in self.layers.values():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
return sum(layer.calc_size() for layer in self.layers.values())
|
||||
@@ -3,26 +3,23 @@ from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
|
||||
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.lora.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.sidecar_layers.concatenated_lora.concatenated_lora_linear_sidecar_layer import (
|
||||
ConcatenatedLoRALinearSidecarLayer,
|
||||
)
|
||||
from invokeai.backend.lora.sidecar_layers.lora.lora_linear_sidecar_layer import LoRALinearSidecarLayer
|
||||
from invokeai.backend.lora.sidecar_layers.lora_sidecar_module import LoRASidecarModule
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
|
||||
class LoRAPatcher:
|
||||
class LayerPatcher:
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_patches(
|
||||
def apply_model_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
patches: Iterable[Tuple[ModelPatchRaw, float]],
|
||||
prefix: str,
|
||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
@@ -40,7 +37,7 @@ class LoRAPatcher:
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher.apply_lora_patch(
|
||||
LayerPatcher.apply_model_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
@@ -52,14 +49,15 @@ class LoRAPatcher:
|
||||
yield
|
||||
finally:
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
model.get_parameter(param_key).copy_(weight)
|
||||
cur_param = model.get_parameter(param_key)
|
||||
cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def apply_lora_patch(
|
||||
def apply_model_patch(
|
||||
model: torch.nn.Module,
|
||||
prefix: str,
|
||||
patch: LoRAModelRaw,
|
||||
patch: ModelPatchRaw,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
@@ -87,46 +85,70 @@ class LoRAPatcher:
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LoRAPatcher._get_submodule(
|
||||
module_key, module = LayerPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
LayerPatcher._apply_model_layer_patch(
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_weights=original_weights,
|
||||
)
|
||||
|
||||
layer_scale = layer.scale()
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_model_layer_patch(
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: BaseLayerPatch,
|
||||
patch_weight: float,
|
||||
original_weights: OriginalWeightsStorage,
|
||||
):
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
first_param = next(module_to_patch.parameters())
|
||||
device = first_param.device
|
||||
dtype = first_param.dtype
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
patch.to(device=device)
|
||||
patch.to(dtype=torch.float32)
|
||||
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
||||
param_key = module_key + "." + param_name
|
||||
module_param = module.get_parameter(param_name)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items():
|
||||
param_key = module_to_patch_key + "." + param_name
|
||||
module_param = module_to_patch.get_parameter(param_name)
|
||||
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
# Save original weight
|
||||
original_weights.save(param_key, module_param)
|
||||
|
||||
if module_param.shape != lora_param_weight.shape:
|
||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
||||
# HACK(ryand): This condition is only necessary to handle layers in FLUX control LoRAs that change the
|
||||
# shape of the original layer.
|
||||
if module_param.nelement() != param_weight.nelement():
|
||||
assert isinstance(patch, FluxControlLoRALayer)
|
||||
expanded_weight = pad_with_zeros(module_param, param_weight.shape)
|
||||
setattr(
|
||||
module_to_patch,
|
||||
param_name,
|
||||
torch.nn.Parameter(expanded_weight, requires_grad=module_param.requires_grad),
|
||||
)
|
||||
module_param = expanded_weight
|
||||
|
||||
lora_param_weight *= patch_weight * layer_scale
|
||||
module_param += lora_param_weight.to(dtype=dtype)
|
||||
module_param += param_weight.to(dtype=dtype)
|
||||
|
||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||
patch.to(device=TorchDevice.CPU_DEVICE)
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
@contextmanager
|
||||
def apply_lora_sidecar_patches(
|
||||
def apply_model_sidecar_patches(
|
||||
model: torch.nn.Module,
|
||||
patches: Iterable[Tuple[LoRAModelRaw, float]],
|
||||
patches: Iterable[Tuple[ModelPatchRaw, float]],
|
||||
prefix: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
@@ -147,7 +169,7 @@ class LoRAPatcher:
|
||||
original_modules: dict[str, torch.nn.Module] = {}
|
||||
try:
|
||||
for patch, patch_weight in patches:
|
||||
LoRAPatcher._apply_lora_sidecar_patch(
|
||||
LayerPatcher._apply_model_sidecar_patch(
|
||||
model=model,
|
||||
prefix=prefix,
|
||||
patch=patch,
|
||||
@@ -160,14 +182,14 @@ class LoRAPatcher:
|
||||
# Restore original modules.
|
||||
# Note: This logic assumes no nested modules in original_modules.
|
||||
for module_key, orig_module in original_modules.items():
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
module_parent_key, module_name = LayerPatcher._split_parent_key(module_key)
|
||||
parent_module = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(parent_module, module_name, orig_module)
|
||||
LayerPatcher._set_submodule(parent_module, module_name, orig_module)
|
||||
|
||||
@staticmethod
|
||||
def _apply_lora_sidecar_patch(
|
||||
def _apply_model_sidecar_patch(
|
||||
model: torch.nn.Module,
|
||||
patch: LoRAModelRaw,
|
||||
patch: ModelPatchRaw,
|
||||
patch_weight: float,
|
||||
prefix: str,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
@@ -190,32 +212,50 @@ class LoRAPatcher:
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key, module = LoRAPatcher._get_submodule(
|
||||
module_key, module = LayerPatcher._get_submodule(
|
||||
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
|
||||
)
|
||||
|
||||
# Initialize the LoRA sidecar layer.
|
||||
lora_sidecar_layer = LoRAPatcher._initialize_lora_sidecar_layer(module, layer, patch_weight)
|
||||
LayerPatcher._apply_model_layer_wrapper_patch(
|
||||
model=model,
|
||||
module_to_patch=module,
|
||||
module_to_patch_key=module_key,
|
||||
patch=layer,
|
||||
patch_weight=patch_weight,
|
||||
original_modules=original_modules,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Replace the original module with a LoRASidecarModule if it has not already been done.
|
||||
if module_key in original_modules:
|
||||
# The module has already been patched with a LoRASidecarModule. Append to it.
|
||||
assert isinstance(module, LoRASidecarModule)
|
||||
lora_sidecar_module = module
|
||||
else:
|
||||
# The module has not yet been patched with a LoRASidecarModule. Create one.
|
||||
lora_sidecar_module = LoRASidecarModule(module, [])
|
||||
original_modules[module_key] = module
|
||||
module_parent_key, module_name = LoRAPatcher._split_parent_key(module_key)
|
||||
module_parent = model.get_submodule(module_parent_key)
|
||||
LoRAPatcher._set_submodule(module_parent, module_name, lora_sidecar_module)
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _apply_model_layer_wrapper_patch(
|
||||
model: torch.nn.Module,
|
||||
module_to_patch: torch.nn.Module,
|
||||
module_to_patch_key: str,
|
||||
patch: BaseLayerPatch,
|
||||
patch_weight: float,
|
||||
original_modules: dict[str, torch.nn.Module],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Apply a single LoRA wrapper patch to a model."""
|
||||
# Replace the original module with a BaseSidecarWrapper if it has not already been done.
|
||||
if not isinstance(module_to_patch, BaseSidecarWrapper):
|
||||
wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch)
|
||||
original_modules[module_to_patch_key] = module_to_patch
|
||||
module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key)
|
||||
module_parent = model.get_submodule(module_parent_key)
|
||||
LayerPatcher._set_submodule(module_parent, module_name, wrapped_module)
|
||||
else:
|
||||
assert module_to_patch_key in original_modules
|
||||
wrapped_module = module_to_patch
|
||||
|
||||
# Move the LoRA sidecar layer to the same device/dtype as the orig module.
|
||||
# TODO(ryand): Experiment with moving to the device first, then casting. This could be faster.
|
||||
lora_sidecar_layer.to(device=lora_sidecar_module.orig_module.weight.device, dtype=dtype)
|
||||
# Move the LoRA layer to the same device/dtype as the orig module.
|
||||
first_param = next(module_to_patch.parameters())
|
||||
device = first_param.device
|
||||
patch.to(device=device, dtype=dtype)
|
||||
|
||||
# Add the LoRA sidecar layer to the LoRASidecarModule.
|
||||
lora_sidecar_module.add_lora_layer(lora_sidecar_layer)
|
||||
# Add the patch to the sidecar wrapper.
|
||||
wrapped_module.add_patch(patch, patch_weight)
|
||||
|
||||
@staticmethod
|
||||
def _split_parent_key(module_key: str) -> tuple[str, str]:
|
||||
@@ -235,21 +275,6 @@ class LoRAPatcher:
|
||||
else:
|
||||
raise ValueError(f"Invalid module key: {module_key}")
|
||||
|
||||
@staticmethod
|
||||
def _initialize_lora_sidecar_layer(orig_layer: torch.nn.Module, lora_layer: AnyLoRALayer, patch_weight: float):
|
||||
# TODO(ryand): Add support for more original layer types and LoRA layer types.
|
||||
if isinstance(orig_layer, torch.nn.Linear) or (
|
||||
isinstance(orig_layer, LoRASidecarModule) and isinstance(orig_layer.orig_module, torch.nn.Linear)
|
||||
):
|
||||
if isinstance(lora_layer, LoRALayer):
|
||||
return LoRALinearSidecarLayer(lora_layer=lora_layer, weight=patch_weight)
|
||||
elif isinstance(lora_layer, ConcatenatedLoRALayer):
|
||||
return ConcatenatedLoRALinearSidecarLayer(concatenated_lora_layer=lora_layer, weight=patch_weight)
|
||||
else:
|
||||
raise ValueError(f"Unsupported Linear LoRA layer type: {type(lora_layer)}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type: {type(orig_layer)}")
|
||||
|
||||
@staticmethod
|
||||
def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
|
||||
try:
|
||||
9
invokeai/backend/patches/pad_with_zeros.py
Normal file
9
invokeai/backend/patches/pad_with_zeros.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import torch
|
||||
|
||||
|
||||
def pad_with_zeros(orig_weight: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
|
||||
"""Pad a weight tensor with zeros to match the target shape."""
|
||||
expanded_weight = torch.zeros(target_shape, dtype=orig_weight.dtype, device=orig_weight.device)
|
||||
slices = tuple(slice(0, dim) for dim in orig_weight.shape)
|
||||
expanded_weight[slices] = orig_weight
|
||||
return expanded_weight
|
||||
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
|
||||
|
||||
class BaseSidecarWrapper(torch.nn.Module):
|
||||
"""A base class for sidecar wrappers.
|
||||
|
||||
A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a
|
||||
list of patches as 'sidecar' patches. I.e. it applies the sidecar patches during forward inference without modifying
|
||||
the original module.
|
||||
|
||||
Sidecar wrappers are typically used over regular patches when:
|
||||
- The original module is quantized and so the weights can't be patched in the usual way.
|
||||
- The original module is on the CPU and modifying the weights would require backing up the original weights and
|
||||
doubling the CPU memory usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
self._orig_module = orig_module
|
||||
self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights
|
||||
|
||||
@property
|
||||
def orig_module(self) -> torch.nn.Module:
|
||||
return self._orig_module
|
||||
|
||||
def add_patch(self, patch: BaseLayerPatch, patch_weight: float):
|
||||
"""Add a patch to the sidecar wrapper."""
|
||||
self._patches_and_weights.append((patch, patch_weight))
|
||||
|
||||
def _aggregate_patch_parameters(
|
||||
self, patches_and_weights: list[tuple[BaseLayerPatch, float]]
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Helper function that aggregates the parameters from all patches into a single dict."""
|
||||
params: dict[str, torch.Tensor] = {}
|
||||
|
||||
for patch, patch_weight in patches_and_weights:
|
||||
# TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original
|
||||
# module, this might fail or return incorrect results.
|
||||
layer_params = patch.get_parameters(self._orig_module, weight=patch_weight)
|
||||
|
||||
for param_name, param_weight in layer_params.items():
|
||||
if param_name not in params:
|
||||
params[param_name] = param_weight
|
||||
else:
|
||||
params[param_name] += param_weight
|
||||
|
||||
return params
|
||||
|
||||
def forward(self, *args, **kwargs): # type: ignore
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class Conv1dSidecarWrapper(BaseSidecarWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
|
||||
return self.orig_module(input) + torch.nn.functional.conv1d(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class Conv2dSidecarWrapper(BaseSidecarWrapper):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights)
|
||||
return self.orig_module(input) + torch.nn.functional.conv1d(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class FluxRMSNormSidecarWrapper(BaseSidecarWrapper):
|
||||
"""A sidecar wrapper for a FLUX RMSNorm layer.
|
||||
|
||||
This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite
|
||||
the RMSNorm scale parameters.
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# Given the narrow focus of this wrapper, we only support a very particular patch configuration:
|
||||
assert len(self._patches_and_weights) == 1
|
||||
patch, _patch_weight = self._patches_and_weights[0]
|
||||
assert isinstance(patch, SetParameterLayer)
|
||||
assert patch.param_name == "scale"
|
||||
|
||||
# Apply the patch.
|
||||
# NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should
|
||||
# be handled.
|
||||
return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6)
|
||||
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper
|
||||
|
||||
|
||||
class LinearSidecarWrapper(BaseSidecarWrapper):
|
||||
def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a Linear LoRALayer."""
|
||||
x = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x = torch.nn.functional.linear(x, lora_layer.mid)
|
||||
x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias)
|
||||
x *= lora_weight * lora_layer.scale()
|
||||
return x
|
||||
|
||||
def _concatenated_lora_forward(
|
||||
self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float
|
||||
) -> torch.Tensor:
|
||||
"""An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer."""
|
||||
x_chunks: list[torch.Tensor] = []
|
||||
for lora_layer in concatenated_lora_layer.lora_layers:
|
||||
x_chunk = torch.nn.functional.linear(input, lora_layer.down)
|
||||
if lora_layer.mid is not None:
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid)
|
||||
x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias)
|
||||
x_chunk *= lora_weight * lora_layer.scale()
|
||||
x_chunks.append(x_chunk)
|
||||
|
||||
# TODO(ryand): Generalize to support concat_axis != 0.
|
||||
assert concatenated_lora_layer.concat_axis == 0
|
||||
x = torch.cat(x_chunks, dim=-1)
|
||||
return x
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# First, apply the original linear layer.
|
||||
# NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which
|
||||
# change the linear layer's in_features.
|
||||
orig_input = input
|
||||
input = orig_input[..., : self.orig_module.in_features]
|
||||
output = self.orig_module(input)
|
||||
|
||||
# Then, apply layers for which we have optimized implementations.
|
||||
unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = []
|
||||
for patch, patch_weight in self._patches_and_weights:
|
||||
if isinstance(patch, FluxControlLoRALayer):
|
||||
# Note that we use the original input here, not the sliced input.
|
||||
output += self._lora_forward(orig_input, patch, patch_weight)
|
||||
elif isinstance(patch, LoRALayer):
|
||||
output += self._lora_forward(input, patch, patch_weight)
|
||||
elif isinstance(patch, ConcatenatedLoRALayer):
|
||||
output += self._concatenated_lora_forward(input, patch, patch_weight)
|
||||
else:
|
||||
unprocessed_patches_and_weights.append((patch, patch_weight))
|
||||
|
||||
# Finally, apply any remaining patches.
|
||||
if len(unprocessed_patches_and_weights) > 0:
|
||||
aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights)
|
||||
output += torch.nn.functional.linear(
|
||||
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
|
||||
)
|
||||
|
||||
return output
|
||||
20
invokeai/backend/patches/sidecar_wrappers/utils.py
Normal file
20
invokeai/backend/patches/sidecar_wrappers/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.modules.layers import RMSNorm
|
||||
from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper
|
||||
from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper
|
||||
|
||||
|
||||
def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module:
|
||||
if isinstance(orig_module, torch.nn.Linear):
|
||||
return LinearSidecarWrapper(orig_module)
|
||||
elif isinstance(orig_module, torch.nn.Conv1d):
|
||||
return Conv1dSidecarWrapper(orig_module)
|
||||
elif isinstance(orig_module, torch.nn.Conv2d):
|
||||
return Conv2dSidecarWrapper(orig_module)
|
||||
elif isinstance(orig_module, RMSNorm):
|
||||
return FluxRMSNormSidecarWrapper(orig_module)
|
||||
else:
|
||||
raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}")
|
||||
@@ -52,6 +52,7 @@ GGML_TENSOR_OP_TABLE = {
|
||||
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore
|
||||
torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore
|
||||
}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_patcher import LoRAPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.patches.model_patcher import LayerPatcher
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -30,8 +30,8 @@ class LoRAExt(ExtensionBase):
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||
lora_model = self._node_context.models.load(self._model_id).model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
LoRAPatcher.apply_lora_patch(
|
||||
assert isinstance(lora_model, ModelPatchRaw)
|
||||
LayerPatcher.apply_model_patch(
|
||||
model=unet,
|
||||
prefix="lora_unet_",
|
||||
patch=lora_model,
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Invoke UI
|
||||
|
||||
<https://invoke-ai.github.io/InvokeAI/contributing/frontend/OVERVIEW/>
|
||||
<https://invoke-ai.github.io/InvokeAI/contributing/frontend/>
|
||||
|
||||
@@ -58,7 +58,7 @@
|
||||
"@dagrejs/dagre": "^1.1.4",
|
||||
"@dagrejs/graphlib": "^2.2.4",
|
||||
"@fontsource-variable/inter": "^5.1.0",
|
||||
"@invoke-ai/ui-library": "^0.0.43",
|
||||
"@invoke-ai/ui-library": "^0.0.44",
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@reduxjs/toolkit": "2.2.3",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user